gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
geneval/metrics/distances.py

@@ -0,0 +1,516 @@
"""
Distribution distance metrics for gene expression evaluation.

Provides Wasserstein, MMD, and Energy distance metrics with per-gene computation.
"""
from __future__ import annotations

import numpy as np
from scipy.stats import wasserstein_distance
from typing import Optional
import warnings

from .base_metric import DistributionMetric


def _ensure_2d(arr: np.ndarray) -> np.ndarray:
    """Ensure array is 2D (samples x genes)."""
    arr = np.asarray(arr, dtype=np.float64)
    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)
    return arr


class Wasserstein1Distance(DistributionMetric):
    """
    Wasserstein-1 (Earth Mover's) distance between distributions.

    Measures the minimum amount of work to transform one distribution
    into another. Computed per gene using 1D Wasserstein distance.

    Lower values indicate more similar distributions.
    """

    def __init__(self):
        super().__init__(
            name="wasserstein_1",
            description="Wasserstein-1 (Earth Mover's) distance per gene"
        )

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute Wasserstein-1 distance for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            W1 distance per gene
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        distances = np.zeros(n_genes)

        for i in range(n_genes):
            r_vals = real[:, i]
            g_vals = generated[:, i]

            # Filter NaN values
            r_vals = r_vals[~np.isnan(r_vals)]
            g_vals = g_vals[~np.isnan(g_vals)]

            if len(r_vals) == 0 or len(g_vals) == 0:
                distances[i] = np.nan
                continue

            distances[i] = wasserstein_distance(r_vals, g_vals)

        return distances


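Usage sketch for the per-gene W1 metric (illustrative, not part of the packaged file; the toy arrays stand in for expression matrices):

# Illustrative only: toy data, not from the package's documentation.
import numpy as np

rng = np.random.default_rng(0)
real = rng.normal(loc=0.0, scale=1.0, size=(200, 3))       # (samples, genes)
generated = rng.normal(loc=0.5, scale=1.0, size=(150, 3))  # means shifted by 0.5

w1 = Wasserstein1Distance()
per_gene = w1.compute_per_gene(real, generated)
print(per_gene)  # each entry near 0.5, the shift between the two distributions
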
class Wasserstein2Distance(DistributionMetric):
    """
    Wasserstein-2 distance (quadratic transport cost) between distributions.

    Uses the p=2 norm for transport cost, making it more sensitive to
    outliers than W1. Computed per gene.
    """

    def __init__(self, use_geomloss: bool = True):
        """
        Parameters
        ----------
        use_geomloss : bool
            If True, use geomloss for GPU-accelerated computation.
            Falls back to a NumPy quantile-based implementation otherwise.
        """
        super().__init__(
            name="wasserstein_2",
            description="Wasserstein-2 distance per gene"
        )
        self.use_geomloss = use_geomloss
        self._geomloss_available = False

        if use_geomloss:
            try:
                import torch
                from geomloss import SamplesLoss
                self._geomloss_available = True
            except ImportError:
                warnings.warn(
                    "geomloss not available, falling back to scipy implementation"
                )

    def _w2_scipy(self, r_vals: np.ndarray, g_vals: np.ndarray) -> float:
        """Compute 1D W2 from matched empirical quantiles (NumPy fallback)."""
        # Sort values; for 1D distributions, W2^2 is the integral of the
        # squared difference between the two quantile functions.
        r_sorted = np.sort(r_vals)
        g_sorted = np.sort(g_vals)

        # Resample both quantile functions onto a common grid
        n = max(len(r_sorted), len(g_sorted))
        r_quantiles = np.interp(
            np.linspace(0, 1, n),
            np.linspace(0, 1, len(r_sorted)),
            r_sorted
        )
        g_quantiles = np.interp(
            np.linspace(0, 1, n),
            np.linspace(0, 1, len(g_sorted)),
            g_sorted
        )

        return np.sqrt(np.mean((r_quantiles - g_quantiles) ** 2))

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute Wasserstein-2 distance for each gene.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        distances = np.zeros(n_genes)

        if self._geomloss_available and self.use_geomloss:
            import torch
            from geomloss import SamplesLoss
            loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, backend="tensorized")

            for i in range(n_genes):
                r_vals = real[:, i]
                g_vals = generated[:, i]

                r_vals = r_vals[~np.isnan(r_vals)]
                g_vals = g_vals[~np.isnan(g_vals)]

                if len(r_vals) == 0 or len(g_vals) == 0:
                    distances[i] = np.nan
                    continue

                # Reshape for geomloss (N, D)
                r_tensor = torch.tensor(r_vals.reshape(-1, 1), dtype=torch.float32)
                g_tensor = torch.tensor(g_vals.reshape(-1, 1), dtype=torch.float32)

                distances[i] = loss_fn(r_tensor, g_tensor).item()
        else:
            for i in range(n_genes):
                r_vals = real[:, i]
                g_vals = generated[:, i]

                r_vals = r_vals[~np.isnan(r_vals)]
                g_vals = g_vals[~np.isnan(g_vals)]

                if len(r_vals) == 0 or len(g_vals) == 0:
                    distances[i] = np.nan
                    continue

                distances[i] = self._w2_scipy(r_vals, g_vals)

        return distances


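The NumPy fallback implements the 1D identity W2(P, Q)^2 = integral over u in [0, 1] of (F_P^-1(u) - F_Q^-1(u))^2, so for equal-size samples it reduces to the RMSE of the sorted values. Worth noting: the geomloss path returns an entropic (Sinkhorn) approximation on a different scale, so the two branches are not directly comparable. A small check of the fallback (illustrative, not from the package):

# Illustrative check of the quantile-based fallback.
import numpy as np

x = np.array([0.0, 1.0, 2.0])
y = np.array([1.0, 2.0, 3.0])  # x shifted by exactly 1

w2 = Wasserstein2Distance(use_geomloss=False)
print(w2._w2_scipy(x, y))  # 1.0: every matched quantile pair differs by 1
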
class MMDDistance(DistributionMetric):
    """
    Maximum Mean Discrepancy (MMD) between distributions.

    Non-parametric distance based on kernel mean embeddings.
    Uses an RBF (Gaussian) kernel. Computed per gene.
    """

    def __init__(self, kernel: str = "rbf", sigma: Optional[float] = None):
        """
        Parameters
        ----------
        kernel : str
            Kernel type ("rbf" for Gaussian)
        sigma : float, optional
            Kernel bandwidth. If None, uses the median heuristic.
        """
        super().__init__(
            name="mmd",
            description="Maximum Mean Discrepancy with RBF kernel"
        )
        self.kernel = kernel
        self.sigma = sigma

    def _rbf_kernel(
        self,
        x: np.ndarray,
        y: np.ndarray,
        sigma: float
    ) -> np.ndarray:
        """Compute RBF kernel matrix."""
        x = x.reshape(-1, 1) if x.ndim == 1 else x
        y = y.reshape(-1, 1) if y.ndim == 1 else y

        # Compute pairwise squared distances
        diff = x[:, np.newaxis, :] - y[np.newaxis, :, :]
        sq_dist = np.sum(diff ** 2, axis=-1)

        return np.exp(-sq_dist / (2 * sigma ** 2))

    def _median_heuristic(self, x: np.ndarray, y: np.ndarray) -> float:
        """Compute bandwidth using the median heuristic."""
        combined = np.concatenate([x, y])
        pairwise = np.abs(combined[:, np.newaxis] - combined[np.newaxis, :])
        return float(np.median(pairwise[pairwise > 0]))

    def _compute_mmd_single(
        self,
        x: np.ndarray,
        y: np.ndarray,
        sigma: Optional[float] = None
    ) -> float:
        """Compute MMD for a single gene."""
        if sigma is None:
            sigma = self._median_heuristic(x, y)
            if sigma == 0:
                sigma = 1.0

        K_xx = self._rbf_kernel(x, x, sigma)
        K_yy = self._rbf_kernel(y, y, sigma)
        K_xy = self._rbf_kernel(x, y, sigma)

        n_x = len(x)
        n_y = len(y)

        # Unbiased MMD^2 estimator (diagonal terms excluded)
        mmd = (
            (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
            (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
            2 * np.sum(K_xy) / (n_x * n_y)
        )

        return max(0.0, mmd)  # Clamp small negatives from the unbiased estimator

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MMD for each gene.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        distances = np.zeros(n_genes)

        for i in range(n_genes):
            r_vals = real[:, i]
            g_vals = generated[:, i]

            r_vals = r_vals[~np.isnan(r_vals)]
            g_vals = g_vals[~np.isnan(g_vals)]

            # The unbiased estimator needs at least two samples per side
            if len(r_vals) < 2 or len(g_vals) < 2:
                distances[i] = np.nan
                continue

            distances[i] = self._compute_mmd_single(r_vals, g_vals, self.sigma)

        return distances


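A quick sanity check of the unbiased estimator (illustrative; the toy data and expected magnitudes are assumptions, not package documentation):

import numpy as np

rng = np.random.default_rng(1)
same_a = rng.normal(size=500)
same_b = rng.normal(size=500)
shifted = rng.normal(loc=2.0, size=500)

mmd = MMDDistance()  # bandwidth from the median heuristic
print(mmd._compute_mmd_single(same_a, same_b))   # near 0: same distribution
print(mmd._compute_mmd_single(same_a, shifted))  # clearly positive: shifted mean
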
class EnergyDistance(DistributionMetric):
    """
    Energy distance between distributions.

    Based on statistical potential energy. Related to, but distinct from,
    the Wasserstein distance. Computed per gene.
    """

    def __init__(self, use_geomloss: bool = True):
        super().__init__(
            name="energy",
            description="Energy distance per gene"
        )
        self.use_geomloss = use_geomloss
        self._geomloss_available = False

        if use_geomloss:
            try:
                import torch
                from geomloss import SamplesLoss
                self._geomloss_available = True
            except ImportError:
                pass

    def _energy_scipy(self, x: np.ndarray, y: np.ndarray) -> float:
        """Compute the squared energy distance (E-statistic) with NumPy."""
        # E[|X - Y|]
        xy_dist = np.mean(np.abs(x[:, np.newaxis] - y[np.newaxis, :]))

        # E[|X - X'|]
        xx_dist = np.mean(np.abs(x[:, np.newaxis] - x[np.newaxis, :]))

        # E[|Y - Y'|]
        yy_dist = np.mean(np.abs(y[:, np.newaxis] - y[np.newaxis, :]))

        energy = 2 * xy_dist - xx_dist - yy_dist
        return max(0.0, energy)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute energy distance for each gene.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        distances = np.zeros(n_genes)

        if self._geomloss_available and self.use_geomloss:
            import torch
            from geomloss import SamplesLoss
            loss_fn = SamplesLoss(loss="energy", blur=0.5, backend="tensorized")

            for i in range(n_genes):
                r_vals = real[:, i]
                g_vals = generated[:, i]

                r_vals = r_vals[~np.isnan(r_vals)]
                g_vals = g_vals[~np.isnan(g_vals)]

                if len(r_vals) == 0 or len(g_vals) == 0:
                    distances[i] = np.nan
                    continue

                r_tensor = torch.tensor(r_vals.reshape(-1, 1), dtype=torch.float32)
                g_tensor = torch.tensor(g_vals.reshape(-1, 1), dtype=torch.float32)

                distances[i] = loss_fn(r_tensor, g_tensor).item()
        else:
            for i in range(n_genes):
                r_vals = real[:, i]
                g_vals = generated[:, i]

                r_vals = r_vals[~np.isnan(r_vals)]
                g_vals = g_vals[~np.isnan(g_vals)]

                if len(r_vals) < 2 or len(g_vals) < 2:
                    distances[i] = np.nan
                    continue

                distances[i] = self._energy_scipy(r_vals, g_vals)

        return distances


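Note on conventions (an observation, not package documentation): _energy_scipy returns the E-statistic 2E|X-Y| - E|X-X'| - E|Y-Y'|, i.e. the squared energy distance, whereas scipy.stats.energy_distance returns its square root. An illustrative cross-check, assuming a scipy version (>= 1.2) that provides energy_distance:

import numpy as np
from scipy.stats import energy_distance

rng = np.random.default_rng(2)
x = rng.normal(size=300)
y = rng.normal(loc=1.0, size=300)

ed = EnergyDistance(use_geomloss=False)
print(ed._energy_scipy(x, y))      # E-statistic (squared energy distance)
print(energy_distance(x, y) ** 2)  # should agree up to floating-point error
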
# Multivariate distance metrics (computed on full gene space)

class MultivariateWasserstein(DistributionMetric):
    """
    Multivariate Wasserstein distance on full gene expression space.

    Unlike per-gene metrics, this computes distance in the joint space
    of all genes. Typically applied after PCA dimensionality reduction.
    """

    def __init__(self, p: int = 2, blur: float = 0.01):
        super().__init__(
            name="multivariate_wasserstein",
            description=f"Multivariate Wasserstein-{p} distance"
        )
        self.p = p
        self.blur = blur

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute multivariate distance (returns same value for all genes).
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        try:
            import torch
            from geomloss import SamplesLoss

            loss_fn = SamplesLoss(
                loss="sinkhorn",
                p=self.p,
                blur=self.blur,
                backend="tensorized"
            )

            r_tensor = torch.tensor(real, dtype=torch.float32)
            g_tensor = torch.tensor(generated, dtype=torch.float32)

            distance = loss_fn(r_tensor, g_tensor).item()
        except ImportError:
            # Fallback: use sliced Wasserstein approximation
            warnings.warn("geomloss not available, using sliced Wasserstein approximation")
            distance = self._sliced_wasserstein(real, generated)

        return np.full(n_genes, distance)

    def _sliced_wasserstein(
        self,
        x: np.ndarray,
        y: np.ndarray,
        n_projections: int = 100
    ) -> float:
        """Compute sliced Wasserstein distance as fallback."""
        d = x.shape[1]

        # Random unit-norm projection directions
        projections = np.random.randn(d, n_projections)
        projections /= np.linalg.norm(projections, axis=0)

        distances = []
        for i in range(n_projections):
            proj = projections[:, i]
            x_proj = x @ proj
            y_proj = y @ proj
            distances.append(wasserstein_distance(x_proj, y_proj))

        return float(np.mean(distances))


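The sliced fallback averages 1D Wasserstein distances over random unit-norm projections drawn from NumPy's global RNG, so repeated calls give slightly different values. Seeding the global RNG (an illustrative workaround, not a package feature) makes runs repeatable:

import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=(100, 10))
y = rng.normal(loc=0.3, size=(100, 10))

np.random.seed(42)  # _sliced_wasserstein draws projections via np.random.randn
mw = MultivariateWasserstein()
print(mw._sliced_wasserstein(x, y, n_projections=200))
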
class MultivariateMMD(DistributionMetric):
    """
    Multivariate MMD on full gene expression space.
    """

    def __init__(self, sigma: Optional[float] = None):
        super().__init__(
            name="multivariate_mmd",
            description="Multivariate MMD with RBF kernel"
        )
        self.sigma = sigma

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute multivariate MMD (returns same value for all genes).
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        n_genes = real.shape[1]

        # Use median heuristic for bandwidth
        if self.sigma is None:
            combined = np.vstack([real, generated])
            pairwise_sq = np.sum(
                (combined[:, np.newaxis, :] - combined[np.newaxis, :, :]) ** 2,
                axis=-1
            )
            sigma = float(np.sqrt(np.median(pairwise_sq[pairwise_sq > 0])))
            if sigma == 0:
                sigma = 1.0
        else:
            sigma = self.sigma

        # Compute kernel matrices
        def rbf_kernel(x, y, sigma):
            pairwise_sq = np.sum(
                (x[:, np.newaxis, :] - y[np.newaxis, :, :]) ** 2,
                axis=-1
            )
            return np.exp(-pairwise_sq / (2 * sigma ** 2))

        K_xx = rbf_kernel(real, real, sigma)
        K_yy = rbf_kernel(generated, generated, sigma)
        K_xy = rbf_kernel(real, generated, sigma)

        n_x, n_y = len(real), len(generated)

        # Unbiased MMD^2 estimator
        mmd = (
            (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
            (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
            2 * np.sum(K_xy) / (n_x * n_y)
        )

        return np.full(n_genes, max(0.0, mmd))

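Putting the file together, an end-to-end sketch (illustrative; it assumes DistributionMetric stores the name passed to its constructor as a `name` attribute):

import numpy as np

rng = np.random.default_rng(4)
real = rng.normal(size=(200, 50))
generated = rng.normal(loc=0.2, size=(180, 50))

for metric in (Wasserstein1Distance(), MMDDistance(), MultivariateMMD()):
    per_gene = metric.compute_per_gene(real, generated)
    print(metric.name, float(np.nanmean(per_gene)))
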
geneval/metrics/metrics.py

@@ -0,0 +1,134 @@
from geomloss import SamplesLoss
import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
from scipy.stats import pearsonr, spearmanr
import torch

from . import metric_MMD


class Metric:
    def __init__(self, name: str, fn):
        self.name = name
        self.fn = fn

    def compute(self, x, y):
        return self.fn(x, y)


class PerturbationMetric:
    def __init__(self, name: str, fn):
        self.name = name
        self.fn = fn

    def compute(self, adata_true: ad.AnnData, adata_generated: ad.AnnData, groupby: str):
        return self.fn(adata_true, adata_generated, groupby)


def compute_metrics(original_data, generated_data, metric_fn):
    # metric_fn is the name of a metric, used as a key into metric_funcs
    metric_funcs = {
        'w1': SamplesLoss(loss="sinkhorn", p=1, blur=0.01),
        'w2': SamplesLoss(loss="sinkhorn", p=2, blur=0.01),
        'mmd': metric_MMD.iface_compute_MMD,
        'energy': SamplesLoss(loss="energy", blur=0.5),
    }
    metric_fn = metric_funcs[metric_fn]
    original_data = torch.tensor(original_data)
    generated_data = torch.tensor(generated_data)
    metric = metric_fn(generated_data, original_data)
    return metric.item()

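Despite its name, the metric_fn parameter is a key into metric_funcs rather than a callable. An illustrative call (toy arrays; the 'mmd' option additionally requires the sibling metric_MMD module):

import numpy as np

rng = np.random.default_rng(5)
real = rng.normal(size=(100, 20)).astype(np.float32)
fake = rng.normal(loc=0.1, size=(100, 20)).astype(np.float32)

print(compute_metrics(real, fake, 'energy'))  # scalar energy distance
print(compute_metrics(real, fake, 'w1'))      # entropic W1 (Sinkhorn)
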
def W1(x, y):
    loss_fn = SamplesLoss(loss="sinkhorn", p=1, blur=0.01, backend="tensorized")
    return loss_fn(x, y).item()


def W2(x, y):
    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, backend="tensorized")
    return loss_fn(x, y).item()


def W1_complete(x, y, preprocess=False):
    # scanpy_preprocessing and scanpy_pca are not defined in this file; they are
    # presumably provided elsewhere in the package (e.g. geneval.utils.preprocessing).
    if preprocess:
        x = scanpy_preprocessing(x)
        y = scanpy_preprocessing(y)

    x_reduced = scanpy_pca(x)
    y_reduced = scanpy_pca(y)

    x_pca = torch.tensor(x_reduced.obsm['X_pca'], dtype=torch.float32)
    y_pca = torch.tensor(y_reduced.obsm['X_pca'], dtype=torch.float32)

    return W1(x_pca, y_pca)

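Since the helpers above are external to this file, an equivalent flow inlined with standard scanpy calls might look like this (illustrative sketch, not the package's own helpers):

import anndata as ad
import numpy as np
import scanpy as sc
import torch

rng = np.random.default_rng(7)
a = ad.AnnData(rng.poisson(2.0, size=(100, 50)).astype(np.float32))
b = ad.AnnData(rng.poisson(2.5, size=(100, 50)).astype(np.float32))

for adata in (a, b):
    sc.pp.normalize_total(adata, target_sum=1e4)  # library-size normalization
    sc.pp.log1p(adata)
    sc.pp.pca(adata, n_comps=20)

x_pca = torch.tensor(a.obsm['X_pca'], dtype=torch.float32)
y_pca = torch.tensor(b.obsm['X_pca'], dtype=torch.float32)
print(W1(x_pca, y_pca))
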
def get_deg_genes(adata: ad.AnnData, groupby: str = "condition_ID", method: str = "wilcoxon", alpha: float = 0.05):
    sc.tl.rank_genes_groups(adata, groupby=groupby, method=method, use_raw=False, n_genes=adata.shape[1])

    degs = set()
    rg_results = adata.uns["rank_genes_groups"]

    for group in rg_results["names"].dtype.names:
        pvals_adj = rg_results["pvals_adj"][group]
        genes = rg_results["names"][group]

        for gene, pval in zip(genes, pvals_adj):
            if pval < alpha:
                degs.add(gene)

    return degs


def get_avg_expression(adata: ad.AnnData, genes: set) -> pd.Series:
    common_genes = list(set(adata.var_names).intersection(genes))
    if len(common_genes) == 0:
        return pd.Series(dtype=float)

    sub_adata = adata[:, common_genes]
    avg_exp = np.array(sub_adata.X.mean(axis=0)).ravel()

    return pd.Series(data=avg_exp, index=common_genes)

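A sketch of how these two helpers might combine with the correlation functions below (illustrative; adata_true and adata_generated are assumed pre-loaded AnnData objects with a condition_ID column in .obs):

# Shared differentially expressed genes -> mean expression -> correlation.
degs_true = get_deg_genes(adata_true, groupby="condition_ID")
degs_gen = get_deg_genes(adata_generated, groupby="condition_ID")
shared = degs_true & degs_gen

avg_true = get_avg_expression(adata_true, shared)
avg_gen = get_avg_expression(adata_generated, shared)
print(compute_pearson(avg_true, avg_gen))  # compute_pearson is defined below
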
def pearson_dict(x, y):
    common_keys = set(x.keys()).intersection(y.keys())
    true_values = [x[key] for key in common_keys]
    calculated_values = [y[key] for key in common_keys]
    correlation, _ = pearsonr(true_values, calculated_values)

    return correlation


def spearman_dict(x, y):
    common_keys = set(x.keys()).intersection(y.keys())
    true_values = [x[key] for key in common_keys]
    calculated_values = [y[key] for key in common_keys]
    correlation, _ = spearmanr(true_values, calculated_values)

    return correlation


def mse_dict(x, y):
    common_keys = set(x.keys()).intersection(y.keys())
    true_values = np.array([x[key] for key in common_keys])
    calculated_values = np.array([y[key] for key in common_keys])
    mse = np.mean((true_values - calculated_values) ** 2)

    return mse


def compute_pearson(x, y):
    common_genes = x.index.intersection(y.index)

    if len(common_genes) == 0:
        return float('nan')

    x_vals = x.loc[common_genes].values
    y_vals = y.loc[common_genes].values

    pearson_corr, _ = pearsonr(x_vals, y_vals)

    return pearson_corr


def compute_spearman(x, y):
    common_genes = x.index.intersection(y.index)

    if len(common_genes) == 0:
        return float('nan')

    x_vals = x.loc[common_genes].values
    y_vals = y.loc[common_genes].values

    spearman_corr, _ = spearmanr(x_vals, y_vals)

    return spearman_corr

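Usage sketch for the dict-based metrics (illustrative; the keys might be gene or perturbation identifiers mapped to scalar summaries):

x = {"geneA": 1.0, "geneB": 2.0, "geneC": 3.0}
y = {"geneA": 1.1, "geneB": 1.9, "geneC": 3.2}

print(pearson_dict(x, y))   # close to 1.0: values track each other
print(spearman_dict(x, y))  # exactly 1.0: identical rank ordering
print(mse_dict(x, y))       # 0.02: mean of (0.1^2, 0.1^2, 0.2^2)
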
geneval/models/__init__.py

@@ -0,0 +1 @@
# This file initializes the models module.

geneval/models/base_model.py

@@ -0,0 +1,53 @@
class BaseModel:
    """
    Base class for all models in the gene expression evaluation system.

    This class provides a foundation for model classes that may be implemented
    in the future. It can include common methods and attributes that all models
    should have.
    """

    def __init__(self):
        pass

    def fit(self, data):
        """
        Fit the model to the provided data.

        Parameters
        ----------
        data : Any
            The data to fit the model on.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def predict(self, data):
        """
        Make predictions using the fitted model.

        Parameters
        ----------
        data : Any
            The data to make predictions on.

        Returns
        -------
        Any
            The predictions made by the model.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def evaluate(self, data):
        """
        Evaluate the model's performance on the provided data.

        Parameters
        ----------
        data : Any
            The data to evaluate the model on.

        Returns
        -------
        Any
            The evaluation metrics.
        """
        raise NotImplementedError("Subclasses should implement this method.")
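An illustrative subclass showing the intended fit/predict/evaluate contract (a hypothetical example, not shipped with the package):

import numpy as np

class MeanBaselineModel(BaseModel):
    """Hypothetical baseline: predicts the per-gene mean of the training data."""

    def fit(self, data):
        self.mean_ = np.asarray(data).mean(axis=0)
        return self

    def predict(self, data):
        n = np.asarray(data).shape[0]
        return np.tile(self.mean_, (n, 1))

    def evaluate(self, data):
        data = np.asarray(data)
        return float(np.mean((data - self.mean_) ** 2))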