gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/testing.py ADDED
@@ -0,0 +1,393 @@
1
+ """
2
+ Testing utilities for GenEval.
3
+
4
+ This module provides mock data generators and testing helpers
5
+ that users can use to test their own integrations with GenEval.
6
+
7
+ Examples
8
+ --------
9
+ >>> from geneval.testing import MockDataGenerator
10
+ >>>
11
+ >>> # Generate synthetic paired datasets
12
+ >>> generator = MockDataGenerator(n_samples=100, n_genes=50, seed=42)
13
+ >>> real, generated = generator.generate_paired_data(noise_level=0.3)
14
+ >>>
15
+ >>> # Use with evaluation
16
+ >>> from geneval import evaluate
17
+ >>> results = evaluate(
18
+ ... real_data=real,
19
+ ... generated_data=generated,
20
+ ... condition_columns=["perturbation"],
21
+ ... )
22
+ """
23
+ from __future__ import annotations
24
+
25
+ from pathlib import Path
26
+ from typing import Dict, List, Optional, Tuple, Union
27
+ import numpy as np
28
+
29
+ try:
30
+ import anndata as ad
31
+ HAS_ANNDATA = True
32
+ except ImportError:
33
+ HAS_ANNDATA = False
34
+
35
+
36
class MockDataGenerator:
    """
    Generator for synthetic gene expression data.

    Creates realistic-looking gene expression data with perturbation
    and cell type effects for testing evaluation pipelines.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    n_genes : int
        Number of genes.
    n_perturbations : int
        Number of different perturbation conditions.
    n_cell_types : int
        Number of different cell types.
    seed : int, optional
        Random seed for reproducibility. When given, every draw comes from a
        private ``np.random.RandomState(seed)``, so the global NumPy random
        state is neither reseeded nor consumed, and results do not depend on
        other code using ``np.random``. When ``None``, draws come from the
        global ``np.random`` state, so an external ``np.random.seed(...)``
        still controls the output.

    Examples
    --------
    >>> generator = MockDataGenerator(n_samples=100, n_genes=50, seed=42)
    >>> real = generator.generate_real_data()
    >>> generated = generator.generate_generated_data(real, noise_level=0.3)
    """

    # quality level -> (multiplier applied to ``noise_level``, additive bias)
    _QUALITY_SETTINGS = {
        "good": (1.0, 0.0),
        "medium": (1.5, 0.1),
        "poor": (2.0, 0.2),
    }

    def __init__(
        self,
        n_samples: int = 100,
        n_genes: int = 50,
        n_perturbations: int = 3,
        n_cell_types: int = 2,
        seed: Optional[int] = None,
    ):
        self.n_samples = n_samples
        self.n_genes = n_genes
        self.n_perturbations = n_perturbations
        self.n_cell_types = n_cell_types
        self.seed = seed

        # A seeded RandomState produces the same stream that np.random.seed(seed)
        # followed by global np.random calls would, but without mutating state
        # shared with the rest of the program. The np.random module exposes the
        # same legacy API, so it serves as the unseeded fallback.
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

        # Generate gene names
        self.gene_names = [f"gene_{i}" for i in range(n_genes)]

        # Generate perturbation names
        self.perturbations = [f"perturbation_{i}" for i in range(n_perturbations)]

        # Generate cell type names
        self.cell_types = [f"cell_type_{i}" for i in range(n_cell_types)]

        # Fixed per-condition additive offsets, one value per gene.
        self._perturbation_effects = {
            p: self._rng.randn(n_genes) * 0.5 for p in self.perturbations
        }
        self._cell_type_effects = {
            c: self._rng.randn(n_genes) * 0.3 for c in self.cell_types
        }

    def generate_real_data(self) -> "ad.AnnData":
        """
        Generate realistic gene expression data.

        Each sample is assigned a random perturbation and cell type. Its
        expression is exponential(1.0) base values plus the matching
        perturbation and cell-type offsets plus Gaussian noise (sd 0.2),
        clipped below at 0.

        Returns
        -------
        AnnData
            ``n_samples x n_genes`` matrix with ``obs["perturbation"]`` and
            ``obs["cell_type"]`` columns, var names ``gene_{i}`` and obs
            names ``cell_{i}``.

        Raises
        ------
        ImportError
            If anndata is not installed.
        """
        if not HAS_ANNDATA:
            raise ImportError("anndata is required for AnnData generation")

        rng = self._rng

        # Assign perturbations and cell types
        perturbations = rng.choice(self.perturbations, self.n_samples)
        cell_types = rng.choice(self.cell_types, self.n_samples)

        # Base expression (non-negative, right-skewed)
        base_expression = rng.exponential(1.0, (self.n_samples, self.n_genes))

        # Add perturbation and cell type effects row by row
        for i, (pert, ct) in enumerate(zip(perturbations, cell_types)):
            base_expression[i] += self._perturbation_effects[pert]
            base_expression[i] += self._cell_type_effects[ct]

        # Add noise
        base_expression += rng.randn(self.n_samples, self.n_genes) * 0.2

        # Clip to realistic range
        base_expression = np.clip(base_expression, 0, None)

        # Create AnnData
        adata = ad.AnnData(X=base_expression)
        adata.var_names = self.gene_names
        adata.obs["perturbation"] = perturbations
        adata.obs["cell_type"] = cell_types
        adata.obs_names = [f"cell_{i}" for i in range(self.n_samples)]

        return adata

    def generate_generated_data(
        self,
        real_data: "ad.AnnData",
        noise_level: float = 0.3,
        quality: str = "good",
    ) -> "ad.AnnData":
        """
        Generate synthetic data matching real data structure.

        Parameters
        ----------
        real_data : AnnData
            Real data to match structure from. Dense or scipy.sparse ``X``
            is accepted; the output ``X`` is always a dense ndarray.
        noise_level : float
            Amount of noise to add (0-1 scale).
        quality : str
            Quality level: "good", "medium", or "poor". Lower quality
            scales the noise up (x1, x1.5, x2) and adds a positive bias
            (0.0, 0.1, 0.2).

        Returns
        -------
        AnnData
            Generated data with same structure as real. ``obs`` is copied
            from ``real_data`` and obs names are ``gen_cell_{i}``.

        Raises
        ------
        ImportError
            If anndata is not installed.
        ValueError
            If ``quality`` is not one of the supported levels.
        """
        if not HAS_ANNDATA:
            raise ImportError("anndata is required for AnnData generation")

        # Validate before drawing any random numbers, so a typo such as
        # "Good" fails loudly instead of silently producing "poor" data.
        try:
            noise_factor, bias = self._QUALITY_SETTINGS[quality]
        except KeyError:
            raise ValueError(
                f"quality must be one of {sorted(self._QUALITY_SETTINGS)}, "
                f"got {quality!r}"
            ) from None
        noise_mult = noise_level * noise_factor

        # Densify sparse matrices. Adding a dense array to a scipy.sparse
        # matrix would yield an np.matrix rather than an ndarray.
        X = real_data.X
        X = X.toarray() if hasattr(X, "toarray") else np.asarray(X, dtype=float)

        # Add noise and bias, then clip to non-negative values
        X = X + self._rng.randn(*X.shape) * noise_mult
        X = X + bias
        X = np.clip(X, 0, None)

        # Create AnnData
        generated = ad.AnnData(X=X)
        generated.var_names = list(real_data.var_names)
        generated.obs = real_data.obs.copy()
        generated.obs_names = [f"gen_cell_{i}" for i in range(len(X))]

        return generated

    def generate_paired_data(
        self,
        noise_level: float = 0.3,
        quality: str = "good",
        include_split: bool = False,
        train_fraction: float = 0.7,
    ) -> Tuple["ad.AnnData", "ad.AnnData"]:
        """
        Generate paired real and generated datasets.

        Parameters
        ----------
        noise_level : float
            Noise level for generated data.
        quality : str
            Quality of generated data.
        include_split : bool
            Whether to include a "split" obs column ("train"/"test"). The
            same shuffled assignment is written to both datasets.
        train_fraction : float
            Fraction of samples in training set (rounded down).

        Returns
        -------
        Tuple[AnnData, AnnData]
            (real_data, generated_data) tuple.
        """
        real = self.generate_real_data()
        generated = self.generate_generated_data(real, noise_level, quality)

        if include_split:
            n_train = int(self.n_samples * train_fraction)
            splits = np.array(["train"] * n_train + ["test"] * (self.n_samples - n_train))
            self._rng.shuffle(splits)
            real.obs["split"] = splits
            generated.obs["split"] = splits

        return real, generated

    def save_paired_data(
        self,
        output_dir: Union[str, Path],
        noise_level: float = 0.3,
        quality: str = "good",
        include_split: bool = True,
        train_fraction: float = 0.7,
    ) -> Tuple[Path, Path]:
        """
        Generate and save paired datasets to h5ad files.

        Parameters
        ----------
        output_dir : str or Path
            Directory to save files. It is created, including parents, if missing.
        noise_level : float
            Noise level for generated data.
        quality : str
            Quality of generated data.
        include_split : bool
            Whether to include train/test split column.
        train_fraction : float
            Fraction of samples in the training set when ``include_split``
            is True.

        Returns
        -------
        Tuple[Path, Path]
            (real_path, generated_path) tuple pointing at
            ``output_dir/real.h5ad`` and ``output_dir/generated.h5ad``.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        real, generated = self.generate_paired_data(
            noise_level=noise_level,
            quality=quality,
            include_split=include_split,
            train_fraction=train_fraction,
        )

        real_path = output_dir / "real.h5ad"
        generated_path = output_dir / "generated.h5ad"

        real.write(real_path)
        generated.write(generated_path)

        return real_path, generated_path
276
+
277
+
278
class MockMetricData:
    """
    Generator for mock metric testing data.

    Creates numpy arrays with specific statistical properties
    for testing metric implementations.

    Parameters
    ----------
    seed : int, optional
        Random seed for reproducibility. When given, draws come from a
        private ``np.random.RandomState(seed)``. The global NumPy random
        state is left untouched, and separate instances do not interfere
        with each other. When ``None``, draws come from the global
        ``np.random`` state.
    """

    def __init__(self, seed: Optional[int] = None):
        self.seed = seed
        # Per-instance stream: same numbers np.random.seed(seed) would give,
        # without reseeding or consuming the caller's global RNG.
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

    def identical_distributions(
        self,
        n_samples: int = 100,
        n_features: int = 50,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate identical distributions (two independent copies) for testing zero distance."""
        data = self._rng.randn(n_samples, n_features)
        return data.copy(), data.copy()

    def similar_distributions(
        self,
        n_samples: int = 100,
        n_features: int = 50,
        noise: float = 0.3,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate similar but not identical distributions: ``generated = real + N(0, noise**2)``."""
        real = self._rng.randn(n_samples, n_features)
        generated = real + self._rng.randn(n_samples, n_features) * noise
        return real, generated

    def different_distributions(
        self,
        n_samples: int = 100,
        n_features: int = 50,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate clearly different distributions (generated mean shifted by +3)."""
        real = self._rng.randn(n_samples, n_features)
        generated = self._rng.randn(n_samples, n_features) + 3.0  # Shifted mean
        return real, generated

    def with_outliers(
        self,
        n_samples: int = 100,
        n_features: int = 50,
        outlier_fraction: float = 0.1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate data with outliers in generated.

        ``generated`` equals ``real`` except for ``int(n_samples * outlier_fraction)``
        randomly chosen rows, which are replaced by N(0, 10**2) draws.
        """
        real = self._rng.randn(n_samples, n_features)
        generated = real.copy()

        n_outliers = int(n_samples * outlier_fraction)
        outlier_indices = self._rng.choice(n_samples, n_outliers, replace=False)
        generated[outlier_indices] = self._rng.randn(n_outliers, n_features) * 10

        return real, generated

    def sparse_data(
        self,
        n_samples: int = 100,
        n_features: int = 50,
        sparsity: float = 0.8,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate sparse data: each entry is independently zeroed with probability ``sparsity``."""
        real = self._rng.randn(n_samples, n_features)
        generated = self._rng.randn(n_samples, n_features)

        # Zero out based on sparsity (independent masks for each array)
        mask_real = self._rng.random_sample((n_samples, n_features)) < sparsity
        mask_gen = self._rng.random_sample((n_samples, n_features)) < sparsity

        real[mask_real] = 0
        generated[mask_gen] = 0

        return real, generated
360
+
361
+
362
+ # Convenience functions
363
def create_test_data(
    n_samples: int = 100,
    n_genes: int = 50,
    noise_level: float = 0.3,
    seed: int = 42,
) -> Tuple["ad.AnnData", "ad.AnnData"]:
    """
    Build a paired (real, generated) synthetic dataset in one call.

    Thin convenience wrapper around ``MockDataGenerator`` using its default
    perturbation/cell-type counts and "good" quality.

    Parameters
    ----------
    n_samples : int
        Number of samples (cells) in each dataset.
    n_genes : int
        Number of genes (features).
    noise_level : float
        Noise level passed to ``generate_paired_data`` for the generated data.
    seed : int
        Seed handed to ``MockDataGenerator``.

    Returns
    -------
    Tuple[AnnData, AnnData]
        (real, generated) tuple.
    """
    mock = MockDataGenerator(n_samples=n_samples, n_genes=n_genes, seed=seed)
    paired = mock.generate_paired_data(noise_level=noise_level)
    return paired
@@ -0,0 +1 @@
1
+ # This file is intentionally left blank.
geneval/utils/io.py ADDED
@@ -0,0 +1,27 @@
1
+ from typing import Any, Dict
2
+ import pandas as pd
3
+ import numpy as np
4
+ import os
5
+ import json
6
+
7
def load_data(file_path: str) -> pd.DataFrame:
    """
    Read a CSV file into a DataFrame.

    Parameters
    ----------
    file_path : str
        Path of the CSV file to read.

    Returns
    -------
    pd.DataFrame
        Contents parsed by ``pd.read_csv`` with default options.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    """
    if os.path.exists(file_path):
        return pd.read_csv(file_path)
    raise FileNotFoundError(f"The file {file_path} does not exist.")
12
+
13
def save_data(data: pd.DataFrame, file_path: str) -> None:
    """
    Write a DataFrame to a CSV file.

    The index is not written, so ``load_data`` reads back only the columns.

    Parameters
    ----------
    data : pd.DataFrame
        Table to write.
    file_path : str
        Destination path; an existing file is overwritten.
    """
    data.to_csv(path_or_buf=file_path, index=False)
16
+
17
def load_json(file_path: str) -> Dict[str, Any]:
    """
    Load data from a JSON file.

    Parameters
    ----------
    file_path : str
        Path of the JSON file to read.

    Returns
    -------
    Dict[str, Any]
        The decoded JSON document.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file {file_path} does not exist.")
    # JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open()
    # uses the locale's preferred encoding, which misreads non-ASCII
    # content on e.g. cp1252 systems.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
23
+
24
def save_json(data: Dict[str, Any], file_path: str) -> None:
    """
    Serialize ``data`` to ``file_path`` as JSON indented by 4 spaces.

    ``json.dump`` escapes non-ASCII characters by default, so the written
    bytes are plain ASCII; the explicit UTF-8 encoding only makes that
    platform independence visible.

    Parameters
    ----------
    data : Dict[str, Any]
        JSON-serializable mapping to write.
    file_path : str
        Destination path; an existing file is overwritten.
    """
    with open(file_path, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=4)
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Union
4
+ import numpy as np
5
+ import pandas as pd
6
+ from scipy import sparse
7
+
8
def normalize_data(data: Union[np.ndarray, pd.DataFrame]) -> Union[np.ndarray, pd.DataFrame]:
    """
    Normalize the gene expression data to have zero mean and unit variance.

    Each column (axis 0 statistics) is centered on its mean and divided by
    its population standard deviation (``ddof=0``). Both input types use the
    same ``ddof``, so a DataFrame and the equivalent ndarray give the same
    numbers. Columns whose standard deviation is exactly zero are only
    centered, which makes them all zeros instead of NaN/inf from a division
    by zero.

    Parameters
    ----------
    data : np.ndarray or pd.DataFrame
        The gene expression data to normalize.

    Returns
    -------
    np.ndarray or pd.DataFrame
        The normalized gene expression data, same type as the input.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor a pandas DataFrame.
    """
    if isinstance(data, pd.DataFrame):
        mean = data.mean()
        # pandas defaults to ddof=1; pin ddof=0 to match np.std below.
        std = data.std(ddof=0)
        # Constant columns: divide by 1 so they become 0, not NaN.
        std = std.mask(std == 0, 1.0)
        return (data - mean) / std
    elif isinstance(data, np.ndarray):
        mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)
        std = np.where(std == 0, 1.0, std)
        return (data - mean) / std
    else:
        raise TypeError("Input data must be a numpy array or a pandas DataFrame.")
28
+
29
def log_transform(data: Union[np.ndarray, pd.DataFrame]) -> Union[np.ndarray, pd.DataFrame]:
    """
    Apply ``log(1 + x)`` element-wise to gene expression data.

    ``np.log1p`` is applied directly, so a DataFrame input yields a DataFrame
    and an ndarray input yields an ndarray.

    Parameters
    ----------
    data : np.ndarray or pd.DataFrame
        The gene expression data to transform.

    Returns
    -------
    np.ndarray or pd.DataFrame
        The log-transformed gene expression data.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor a pandas DataFrame.
    """
    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise TypeError("Input data must be a numpy array or a pandas DataFrame.")
    return np.log1p(data)
49
+
50
def scale_data(data: Union[np.ndarray, pd.DataFrame], min_val: float = 0, max_val: float = 1) -> Union[np.ndarray, pd.DataFrame]:
    """
    Scale the gene expression data to a specified range.

    Each column is min-max scaled independently (statistics over axis 0).
    A column whose max equals its min has no spread to scale; it is mapped
    to ``min_val`` instead of producing NaN from a division by zero.

    Parameters
    ----------
    data : np.ndarray or pd.DataFrame
        The gene expression data to scale.
    min_val : float
        The minimum value of the scaled data.
    max_val : float
        The maximum value of the scaled data.

    Returns
    -------
    np.ndarray or pd.DataFrame
        The scaled gene expression data, same type as the input.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor a pandas DataFrame.
    """
    target_range = max_val - min_val
    if isinstance(data, pd.DataFrame):
        col_min = data.min()
        span = data.max() - col_min
        # Constant columns: (x - min) is 0, so dividing by 1 yields min_val.
        span = span.mask(span == 0, 1.0)
        return (data - col_min) / span * target_range + min_val
    elif isinstance(data, np.ndarray):
        col_min = np.min(data, axis=0)
        span = np.max(data, axis=0) - col_min
        span = np.where(span == 0, 1.0, span)
        return (data - col_min) / span * target_range + min_val
    else:
        raise TypeError("Input data must be a numpy array or a pandas DataFrame.")
74
+
75
def to_dense(X):
    """
    Return ``X`` as a dense NumPy array.

    scipy.sparse inputs are densified with ``toarray()``, which allocates a
    new array. Anything else goes through ``np.asarray``, which hands back an
    existing ndarray as-is rather than copying it.
    """
    return X.toarray() if sparse.issparse(X) else np.asarray(X)
@@ -0,0 +1,38 @@
1
+ """
2
+ Visualization module for gene expression evaluation.
3
+
4
+ Provides publication-quality plots:
5
+ - Boxplots and violin plots for metric distributions
6
+ - Radar plots for multi-metric comparison
7
+ - Scatter plots for real vs generated expression
8
+ - Embedding plots (PCA, UMAP)
9
+ - Heatmaps for per-gene metrics
10
+ """
11
+
12
+ from .plots import (
13
+ EvaluationPlotter,
14
+ create_boxplot,
15
+ create_violin_plot,
16
+ create_heatmap,
17
+ create_scatter,
18
+ create_radar_chart,
19
+ )
20
+ from .visualizer import (
21
+ EvaluationVisualizer,
22
+ PlotStyle,
23
+ visualize,
24
+ )
25
+
26
+ __all__ = [
27
+ # Classes
28
+ "EvaluationPlotter",
29
+ "EvaluationVisualizer",
30
+ "PlotStyle",
31
+ # Functions
32
+ "visualize",
33
+ "create_boxplot",
34
+ "create_violin_plot",
35
+ "create_heatmap",
36
+ "create_scatter",
37
+ "create_radar_chart",
38
+ ]