gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,516 @@
1
+ """
2
+ Distribution distance metrics for gene expression evaluation.
3
+
4
+ Provides Wasserstein, MMD, and Energy distance metrics with per-gene computation.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ from scipy.stats import wasserstein_distance
10
+ from typing import Optional, Tuple
11
+ import warnings
12
+
13
+ from .base_metric import DistributionMetric
14
+
15
+
16
+ def _ensure_2d(arr: np.ndarray) -> np.ndarray:
17
+ """Ensure array is 2D (samples x genes)."""
18
+ arr = np.asarray(arr, dtype=np.float64)
19
+ if arr.ndim == 1:
20
+ arr = arr.reshape(-1, 1)
21
+ return arr
22
+
23
+
24
class Wasserstein1Distance(DistributionMetric):
    """
    Wasserstein-1 (Earth Mover's) distance between distributions.

    Measures the minimum amount of work to transform one distribution
    into another. Computed per gene using 1D Wasserstein distance.

    Lower values indicate more similar distributions.
    """

    def __init__(self):
        super().__init__(
            name="wasserstein_1",
            description="Wasserstein-1 (Earth Mover's) distance per gene"
        )

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute Wasserstein-1 distance for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            W1 distance per gene; NaN where either side has no finite
            values for that gene.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)

        # Start from NaN so genes with no usable data stay NaN.
        out = np.full(real.shape[1], np.nan)

        for gene_idx in range(real.shape[1]):
            r = real[:, gene_idx]
            g = generated[:, gene_idx]
            r = r[~np.isnan(r)]
            g = g[~np.isnan(g)]
            if r.size and g.size:
                out[gene_idx] = wasserstein_distance(r, g)

        return out
81
+
82
+
83
class Wasserstein2Distance(DistributionMetric):
    """
    Wasserstein-2 distance (quadratic cost) between distributions.

    Uses p=2 norm for transport cost. More sensitive to outliers than W1.
    Computed per gene.
    """

    def __init__(self, use_geomloss: bool = True):
        """
        Parameters
        ----------
        use_geomloss : bool
            If True, use geomloss for GPU-accelerated computation.
            Falls back to scipy otherwise.
        """
        super().__init__(
            name="wasserstein_2",
            description="Wasserstein-2 distance per gene"
        )
        self.use_geomloss = use_geomloss
        self._geomloss_available = False

        if use_geomloss:
            try:
                import torch  # noqa: F401
                from geomloss import SamplesLoss  # noqa: F401
            except ImportError:
                warnings.warn(
                    "geomloss not available, falling back to scipy implementation"
                )
            else:
                self._geomloss_available = True

    def _w2_scipy(self, r_vals: np.ndarray, g_vals: np.ndarray) -> float:
        """Compute W2 using a quantile-function approximation (1D closed form)."""
        r_sorted = np.sort(r_vals)
        g_sorted = np.sort(g_vals)

        # Interpolate both empirical quantile functions onto a common grid
        # so samples of different sizes can be compared pointwise.
        n = max(r_sorted.size, g_sorted.size)
        grid = np.linspace(0, 1, n)
        r_quantiles = np.interp(grid, np.linspace(0, 1, r_sorted.size), r_sorted)
        g_quantiles = np.interp(grid, np.linspace(0, 1, g_sorted.size), g_sorted)

        return np.sqrt(np.mean((r_quantiles - g_quantiles) ** 2))

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute Wasserstein-2 distance for each gene.

        NaN entries are filtered per gene; genes where either side is
        empty after filtering yield NaN.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]
        distances = np.full(n_genes, np.nan)

        gpu_path = self._geomloss_available and self.use_geomloss
        if gpu_path:
            import torch
            from geomloss import SamplesLoss
            loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, backend="tensorized")

        for idx in range(n_genes):
            r = real[:, idx]
            g = generated[:, idx]
            r = r[~np.isnan(r)]
            g = g[~np.isnan(g)]

            if r.size == 0 or g.size == 0:
                continue

            if gpu_path:
                # geomloss expects (N, D) point clouds.
                r_t = torch.tensor(r.reshape(-1, 1), dtype=torch.float32)
                g_t = torch.tensor(g.reshape(-1, 1), dtype=torch.float32)
                distances[idx] = loss_fn(r_t, g_t).item()
            else:
                distances[idx] = self._w2_scipy(r, g)

        return distances
187
+
188
+
189
class MMDDistance(DistributionMetric):
    """
    Maximum Mean Discrepancy (MMD) between distributions.

    Non-parametric distance based on kernel embeddings.
    Uses RBF (Gaussian) kernel. Computed per gene.
    """

    def __init__(self, kernel: str = "rbf", sigma: Optional[float] = None):
        """
        Parameters
        ----------
        kernel : str
            Kernel type ("rbf" for Gaussian)
        sigma : float, optional
            Kernel bandwidth. If None, uses median heuristic.
        """
        super().__init__(
            name="mmd",
            description="Maximum Mean Discrepancy with RBF kernel"
        )
        self.kernel = kernel
        self.sigma = sigma

    def _rbf_kernel(
        self,
        x: np.ndarray,
        y: np.ndarray,
        sigma: float
    ) -> np.ndarray:
        """Compute the RBF (Gaussian) kernel matrix, shape (len(x), len(y))."""
        x = x.reshape(-1, 1) if x.ndim == 1 else x
        y = y.reshape(-1, 1) if y.ndim == 1 else y

        # Pairwise squared Euclidean distances via broadcasting.
        diff = x[:, np.newaxis, :] - y[np.newaxis, :, :]
        sq_dist = np.sum(diff ** 2, axis=-1)

        return np.exp(-sq_dist / (2 * sigma ** 2))

    def _median_heuristic(self, x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute bandwidth using the median heuristic.

        Returns the median of the positive pairwise distances of the
        pooled samples. When all pooled values coincide there are no
        positive distances; previously this took the median of an empty
        array (NaN, which the downstream ``sigma == 0`` guard could not
        catch). We return 1.0 in that degenerate case instead.
        """
        combined = np.concatenate([x, y])
        pairwise = np.abs(combined[:, np.newaxis] - combined[np.newaxis, :])
        positive = pairwise[pairwise > 0]
        if positive.size == 0:
            return 1.0
        return float(np.median(positive))

    def _compute_mmd_single(
        self,
        x: np.ndarray,
        y: np.ndarray,
        sigma: Optional[float] = None
    ) -> float:
        """Unbiased MMD^2 estimate for a single gene, clipped at zero."""
        if sigma is None:
            sigma = self._median_heuristic(x, y)
            if sigma == 0:
                sigma = 1.0

        K_xx = self._rbf_kernel(x, x, sigma)
        K_yy = self._rbf_kernel(y, y, sigma)
        K_xy = self._rbf_kernel(x, y, sigma)

        n_x = len(x)
        n_y = len(y)

        # Unbiased MMD estimator: off-diagonal means of the within-sample
        # kernels minus twice the cross-kernel mean. Callers must ensure
        # n_x >= 2 and n_y >= 2 (the n*(n-1) denominators).
        mmd = (
            (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
            (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
            2 * np.sum(K_xy) / (n_x * n_y)
        )

        # The unbiased estimator can go slightly negative; clip to 0.
        return max(0.0, mmd)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MMD for each gene.

        NaNs are filtered per gene; genes with fewer than two remaining
        values on either side yield NaN (the unbiased estimator needs at
        least two samples per side).
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        distances = np.zeros(n_genes)

        for i in range(n_genes):
            r_vals = real[:, i]
            g_vals = generated[:, i]

            r_vals = r_vals[~np.isnan(r_vals)]
            g_vals = g_vals[~np.isnan(g_vals)]

            if len(r_vals) < 2 or len(g_vals) < 2:
                distances[i] = np.nan
                continue

            distances[i] = self._compute_mmd_single(r_vals, g_vals, self.sigma)

        return distances
291
+
292
+
293
class EnergyDistance(DistributionMetric):
    """
    Energy distance between distributions.

    Based on statistical potential energy. Related to but different from
    Wasserstein distance. Computed per gene.
    """

    def __init__(self, use_geomloss: bool = True):
        super().__init__(
            name="energy",
            description="Energy distance per gene"
        )
        self.use_geomloss = use_geomloss
        self._geomloss_available = False

        if use_geomloss:
            try:
                import torch  # noqa: F401
                from geomloss import SamplesLoss  # noqa: F401
            except ImportError:
                # Silently fall back to the numpy implementation.
                pass
            else:
                self._geomloss_available = True

    def _energy_scipy(self, x: np.ndarray, y: np.ndarray) -> float:
        """Energy distance 2*E|X-Y| - E|X-X'| - E|Y-Y'|, clipped at zero."""
        cross = np.mean(np.abs(x[:, np.newaxis] - y[np.newaxis, :]))
        within_x = np.mean(np.abs(x[:, np.newaxis] - x[np.newaxis, :]))
        within_y = np.mean(np.abs(y[:, np.newaxis] - y[np.newaxis, :]))
        return max(0, 2 * cross - within_x - within_y)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute energy distance for each gene.

        NaNs are filtered per gene; genes with too few remaining values
        yield NaN.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]
        distances = np.full(n_genes, np.nan)

        gpu_path = self._geomloss_available and self.use_geomloss
        if gpu_path:
            import torch
            from geomloss import SamplesLoss
            loss_fn = SamplesLoss(loss="energy", blur=0.5, backend="tensorized")

        for idx in range(n_genes):
            r = real[:, idx]
            g = generated[:, idx]
            r = r[~np.isnan(r)]
            g = g[~np.isnan(g)]

            if gpu_path:
                # geomloss path accepts any non-empty sample.
                if r.size == 0 or g.size == 0:
                    continue
                r_t = torch.tensor(r.reshape(-1, 1), dtype=torch.float32)
                g_t = torch.tensor(g.reshape(-1, 1), dtype=torch.float32)
                distances[idx] = loss_fn(r_t, g_t).item()
            else:
                # numpy path keeps the original >= 2 samples requirement.
                if r.size < 2 or g.size < 2:
                    continue
                distances[idx] = self._energy_scipy(r, g)

        return distances
382
+
383
+
384
+ # Multivariate distance metrics (computed on full gene space)
385
+
386
class MultivariateWasserstein(DistributionMetric):
    """
    Multivariate Wasserstein distance on full gene expression space.

    Unlike per-gene metrics, this computes distance in the joint space
    of all genes. Typically applied after PCA dimensionality reduction.
    """

    def __init__(self, p: int = 2, blur: float = 0.01):
        super().__init__(
            name="multivariate_wasserstein",
            description=f"Multivariate Wasserstein-{p} distance"
        )
        self.p = p
        self.blur = blur

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute multivariate distance (returns same value for all genes).
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)

        try:
            import torch
            from geomloss import SamplesLoss
        except ImportError:
            # Fallback: sliced Wasserstein approximation in pure numpy.
            warnings.warn("geomloss not available, using sliced Wasserstein approximation")
            distance = self._sliced_wasserstein(real, generated)
        else:
            sinkhorn = SamplesLoss(
                loss="sinkhorn",
                p=self.p,
                blur=self.blur,
                backend="tensorized"
            )
            distance = sinkhorn(
                torch.tensor(real, dtype=torch.float32),
                torch.tensor(generated, dtype=torch.float32),
            ).item()

        # Broadcast the single joint-space value across all gene slots.
        return np.full(real.shape[1], distance)

    def _sliced_wasserstein(
        self,
        x: np.ndarray,
        y: np.ndarray,
        n_projections: int = 100
    ) -> float:
        """Sliced Wasserstein: average 1D W1 over random unit projections."""
        dims = x.shape[1]

        # Random directions, normalized to unit length column-wise.
        # NOTE: uses the global numpy RNG, so results vary between calls.
        directions = np.random.randn(dims, n_projections)
        directions /= np.linalg.norm(directions, axis=0)

        per_projection = [
            wasserstein_distance(x @ directions[:, j], y @ directions[:, j])
            for j in range(n_projections)
        ]
        return float(np.mean(per_projection))
457
+
458
+
459
class MultivariateMMD(DistributionMetric):
    """
    Multivariate MMD on full gene expression space.
    """

    def __init__(self, sigma: Optional[float] = None):
        """
        Parameters
        ----------
        sigma : float, optional
            RBF kernel bandwidth. If None, the median heuristic is used.
        """
        super().__init__(
            name="multivariate_mmd",
            description="Multivariate MMD with RBF kernel"
        )
        self.sigma = sigma

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute multivariate MMD (same value broadcast to every gene).

        Returns an all-NaN array when either side has fewer than two
        samples: the unbiased estimator divides by n*(n-1), which was
        previously a division by zero for n == 1.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]
        n_x, n_y = len(real), len(generated)

        if n_x < 2 or n_y < 2:
            return np.full(n_genes, np.nan)

        # Use median heuristic for bandwidth
        if self.sigma is None:
            combined = np.vstack([real, generated])
            pairwise_sq = np.sum(
                (combined[:, np.newaxis, :] - combined[np.newaxis, :, :]) ** 2,
                axis=-1
            )
            positive = pairwise_sq[pairwise_sq > 0]
            # All points identical -> no positive distances; previously the
            # median of an empty array produced NaN. Treat as sigma = 0 so
            # the fallback below applies.
            sigma = float(np.sqrt(np.median(positive))) if positive.size else 0.0
            if sigma == 0:
                sigma = 1.0
        else:
            sigma = self.sigma

        # Compute kernel matrices
        def rbf_kernel(x, y, sigma):
            pairwise_sq = np.sum(
                (x[:, np.newaxis, :] - y[np.newaxis, :, :]) ** 2,
                axis=-1
            )
            return np.exp(-pairwise_sq / (2 * sigma ** 2))

        K_xx = rbf_kernel(real, real, sigma)
        K_yy = rbf_kernel(generated, generated, sigma)
        K_xy = rbf_kernel(real, generated, sigma)

        # Unbiased MMD estimator (off-diagonal means minus cross term).
        mmd = (
            (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
            (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
            2 * np.sum(K_xy) / (n_x * n_y)
        )

        # Clip slight negative estimates to zero.
        return np.full(n_genes, max(0.0, mmd))
@@ -0,0 +1,134 @@
1
+ from geomloss import SamplesLoss
2
+ import anndata as ad
3
+ import scanpy as sc
4
+ import pandas as pd
5
+ import numpy as np
6
+ from scipy.stats import pearsonr, spearmanr
7
+ import torch
8
+ from . import metric_MMD
9
+
10
class Metric():
    """Pair a metric name with the callable that computes it."""

    def __init__(self, name: str, fn):
        # fn is any two-argument callable returning the metric value.
        self.name = name
        self.fn = fn

    def compute(self, x, y):
        """Apply the wrapped callable to (x, y) and return its result."""
        return self.fn(x, y)
17
+
18
class PerturbationMetric():
    """Pair a metric name with a callable evaluated on grouped AnnData pairs."""

    def __init__(self, name: str, fn):
        # fn is a callable taking (adata_true, adata_generated, groupby).
        self.name = name
        self.fn = fn

    def compute(self, adata_true: ad.AnnData, adata_generated: ad.AnnData, groupby: str):
        """Delegate to the wrapped callable and return its result."""
        return self.fn(adata_true, adata_generated, groupby)
25
+
26
def compute_metrics(original_data, generated_data, metric_fn):
    """
    Compute a scalar distribution distance between two data matrices.

    Parameters
    ----------
    original_data : array-like
        Reference samples.
    generated_data : array-like
        Generated samples.
    metric_fn : str
        Metric key: one of 'w1', 'w2', 'mmd', 'energy'.
        An unknown key raises KeyError.

    Returns
    -------
    float
        The scalar metric value.
    """
    dispatch = {
        'w1': SamplesLoss(loss="sinkhorn", p=1, blur=0.01),
        'w2': SamplesLoss(loss="sinkhorn", p=2, blur=0.01),
        'mmd': metric_MMD.iface_compute_MMD,
        'energy': SamplesLoss(loss="energy", blur=0.5),
    }
    chosen = dispatch[metric_fn]
    original_tensor = torch.tensor(original_data)
    generated_tensor = torch.tensor(generated_data)
    # Note: the metric is evaluated as metric(generated, original).
    return chosen(generated_tensor, original_tensor).item()
38
+
39
def W1(x, y):
    """Sinkhorn-approximated Wasserstein-1 distance between point clouds x and y."""
    sinkhorn = SamplesLoss(loss="sinkhorn", p=1, blur=0.01, backend="tensorized")
    return sinkhorn(x, y).item()
42
+
43
def W2(x, y):
    """Sinkhorn-approximated Wasserstein-2 distance between point clouds x and y."""
    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, backend="tensorized")
    return sinkhorn(x, y).item()
46
+
47
def W1_complete(x, y, preprocess=False):
    """
    Wasserstein-1 distance between two datasets in PCA space.

    Optionally preprocesses both inputs, reduces each with PCA, then
    computes W1 on the 'X_pca' coordinates stored in ``.obsm``.

    NOTE(review): ``scanpy_preprocessing`` and ``scanpy_pca`` are not
    defined or imported in this module as shown — confirm where they come
    from; otherwise this function raises NameError at runtime.
    NOTE(review): each input appears to be projected onto its *own* PCA
    basis before comparison — verify this is intended.
    """
    if preprocess:
        x = scanpy_preprocessing(x)
        y = scanpy_preprocessing(y)

    x_reduced = scanpy_pca(x)
    y_reduced = scanpy_pca(y)

    x_pca = torch.tensor(x_reduced.obsm['X_pca'], dtype=torch.float32)
    y_pca = torch.tensor(y_reduced.obsm['X_pca'], dtype=torch.float32)

    return W1(x_pca, y_pca)
59
+
60
def get_deg_genes(adata: ad.AnnData, groupby: str = "condition_ID", method: str = "wilcoxon", alpha: float = 0.05):
    """
    Collect genes differentially expressed in at least one group.

    Runs scanpy's rank_genes_groups over all genes and returns the set of
    gene names whose adjusted p-value is below ``alpha`` in any group.
    """
    sc.tl.rank_genes_groups(adata, groupby=groupby, method=method, use_raw=False, n_genes=adata.shape[1])

    results = adata.uns["rank_genes_groups"]
    significant = set()

    for group in results["names"].dtype.names:
        names = results["names"][group]
        padj = results["pvals_adj"][group]
        significant.update(gene for gene, p in zip(names, padj) if p < alpha)

    return significant
75
+
76
def get_avg_expression(adata: ad.AnnData, genes: set) -> pd.Series:
    """
    Mean expression of the requested genes across all observations.

    Returns a Series indexed by the requested genes found in
    ``adata.var_names``; an empty float Series if none are present.
    """
    present = list(set(adata.var_names).intersection(genes))
    if not present:
        return pd.Series(dtype=float)

    # .mean over axis 0 works for dense and sparse X; ravel flattens
    # the possible matrix result to 1D.
    means = np.array(adata[:, present].X.mean(axis=0)).ravel()
    return pd.Series(data=means, index=present)
85
+
86
def pearson_dict(x, y):
    """
    Pearson correlation between the values of two dicts on shared keys.

    Parameters
    ----------
    x, y : dict
        Mappings from key to numeric value.

    Returns
    -------
    float
        Pearson r over the common keys, or NaN when fewer than two keys
        are shared (the correlation is undefined there; previously
        ``pearsonr`` raised for length < 2).
    """
    common_keys = list(set(x.keys()).intersection(y.keys()))
    if len(common_keys) < 2:
        return float('nan')

    true_values = [x[key] for key in common_keys]
    calculated_values = [y[key] for key in common_keys]
    correlation, _ = pearsonr(true_values, calculated_values)

    return correlation
93
+
94
def spearman_dict(x, y):
    """
    Spearman rank correlation between two dicts' values on shared keys.

    Parameters
    ----------
    x, y : dict
        Mappings from key to numeric value.

    Returns
    -------
    float
        Spearman rho over the common keys, or NaN when fewer than two
        keys are shared (the correlation is undefined there).
    """
    common_keys = list(set(x.keys()).intersection(y.keys()))
    if len(common_keys) < 2:
        return float('nan')

    true_values = [x[key] for key in common_keys]
    calculated_values = [y[key] for key in common_keys]
    correlation, _ = spearmanr(true_values, calculated_values)

    return correlation
101
+
102
def mse_dict(x, y):
    """
    Mean squared error between two dicts' values on their shared keys.

    Iterates the same set object for both dicts so values stay aligned.
    Returns NaN (with a numpy warning) when no keys are shared.
    """
    shared = set(x.keys()).intersection(y.keys())
    truth = np.array([x[k] for k in shared])
    estimate = np.array([y[k] for k in shared])
    return np.mean((truth - estimate) ** 2)
109
+
110
def compute_pearson(x, y):
    """
    Pearson correlation between two pandas Series on their shared index.

    Parameters
    ----------
    x, y : pd.Series
        Numeric series indexed by gene name.

    Returns
    -------
    float
        Pearson r over the common index labels, or NaN when fewer than
        two labels are shared (``pearsonr`` requires length >= 2;
        previously a single shared label raised ValueError).
    """
    common_genes = x.index.intersection(y.index)

    if len(common_genes) < 2:
        return float('nan')

    x_vals = x.loc[common_genes].values
    y_vals = y.loc[common_genes].values

    pearson_corr, _ = pearsonr(x_vals, y_vals)

    return pearson_corr
122
+
123
def compute_spearman(x, y):
    """
    Spearman rank correlation between two pandas Series on their shared index.

    Parameters
    ----------
    x, y : pd.Series
        Numeric series indexed by gene name.

    Returns
    -------
    float
        Spearman rho over the common index labels, or NaN when fewer
        than two labels are shared (the statistic is undefined there).
    """
    common_genes = x.index.intersection(y.index)

    if len(common_genes) < 2:
        return float('nan')

    x_vals = x.loc[common_genes].values
    y_vals = y.loc[common_genes].values

    spearman_corr, _ = spearmanr(x_vals, y_vals)

    return spearman_corr
133
+
134
+ return spearman_corr
@@ -0,0 +1 @@
1
+ # This file initializes the models module.
@@ -0,0 +1,53 @@
1
class BaseModel:
    """
    Abstract foundation for models in the gene expression evaluation system.

    Subclasses are expected to override ``fit``, ``predict`` and
    ``evaluate``; every base implementation raises NotImplementedError.
    """

    def __init__(self):
        # No shared state yet; subclasses may extend as needed.
        pass

    def fit(self, data):
        """
        Fit the model to the provided data. Must be overridden.

        Parameters
        ----------
        data : Any
            The data to fit the model on.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def predict(self, data):
        """
        Make predictions using the fitted model. Must be overridden.

        Parameters
        ----------
        data : Any
            The data to make predictions on.

        Returns
        -------
        Any
            The predictions made by the model.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def evaluate(self, data):
        """
        Evaluate the model's performance on the provided data. Must be overridden.

        Parameters
        ----------
        data : Any
            The data to evaluate the model on.

        Returns
        -------
        Any
            The evaluation metrics.
        """
        raise NotImplementedError("Subclasses should implement this method.")