pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,35 +1,46 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import multiprocessing
|
3
4
|
from abc import ABC, abstractmethod
|
4
|
-
from typing import TYPE_CHECKING
|
5
|
+
from typing import TYPE_CHECKING, Literal, NamedTuple
|
5
6
|
|
7
|
+
import numba
|
6
8
|
import numpy as np
|
7
9
|
import pandas as pd
|
8
10
|
from ott.geometry.geometry import Geometry
|
9
11
|
from ott.geometry.pointcloud import PointCloud
|
10
12
|
from ott.problems.linear.linear_problem import LinearProblem
|
11
13
|
from ott.solvers.linear.sinkhorn import Sinkhorn
|
14
|
+
from pandas import Series
|
12
15
|
from rich.progress import track
|
13
16
|
from scipy.sparse import issparse
|
14
|
-
from scipy.spatial.distance import cosine
|
17
|
+
from scipy.spatial.distance import cosine, mahalanobis
|
15
18
|
from scipy.special import gammaln
|
16
|
-
from scipy.stats import kendalltau, pearsonr, spearmanr
|
19
|
+
from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
|
20
|
+
from sklearn.linear_model import LogisticRegression
|
17
21
|
from sklearn.metrics import pairwise_distances, r2_score
|
18
22
|
from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
|
23
|
+
from sklearn.neighbors import KernelDensity
|
19
24
|
from statsmodels.discrete.discrete_model import NegativeBinomialP
|
20
25
|
|
21
26
|
if TYPE_CHECKING:
|
22
|
-
from collections.abc import
|
27
|
+
from collections.abc import Callable
|
23
28
|
|
24
29
|
from anndata import AnnData
|
25
30
|
|
26
31
|
|
32
|
+
class MeanVar(NamedTuple):
|
33
|
+
mean: float
|
34
|
+
variance: float
|
35
|
+
|
36
|
+
|
27
37
|
class Distance:
|
28
38
|
"""Distance class, used to compute distances between groups of cells.
|
29
39
|
|
30
40
|
The distance metric can be specified by the user. This class also provides a
|
31
41
|
method to compute the pairwise distances between all groups of cells.
|
32
42
|
Currently available metrics:
|
43
|
+
|
33
44
|
- "edistance": Energy distance (Default metric).
|
34
45
|
In essence, it is twice the mean pairwise distance between cells of two
|
35
46
|
groups minus the mean pairwise distance between cells within each group
|
@@ -55,8 +66,6 @@ class Distance:
|
|
55
66
|
Coefficient of determination distance between the means of cells from two groups.
|
56
67
|
- "mean_pairwise": Mean pairwise distance.
|
57
68
|
Mean of the pairwise euclidean distances between cells of two groups.
|
58
|
-
- "mean_pairwise": Mean pairwise distance.
|
59
|
-
Mean of the pairwise euclidean distances between cells of two groups.
|
60
69
|
- "mmd": Maximum mean discrepancy
|
61
70
|
Maximum mean discrepancy between the cells of two groups.
|
62
71
|
Here, uses linear, rbf, and quadratic polynomial MMD. For theory on MMD in single-cell applications, see
|
@@ -66,14 +75,25 @@ class Distance:
|
|
66
75
|
OTT-JAX implementation of the Sinkhorn algorithm to compute the distance.
|
67
76
|
For more information on the optimal transport solver, see
|
68
77
|
`Cuturi et al. (2013) <https://proceedings.neurips.cc/paper/2013/file/af21d0c97db2e27e13572cbf59eb343d-Paper.pdf>`__.
|
69
|
-
- "
|
78
|
+
- "sym_kldiv": symmetrized Kullback–Leibler divergence distance.
|
70
79
|
Kullback–Leibler divergence of the gaussian distributions between cells of two groups.
|
71
|
-
Here we fit a gaussian distribution over
|
80
|
+
Here we fit a gaussian distribution over one group of cells and then calculate the KL divergence on the other, and vice versa.
|
72
81
|
- "t_test": t-test statistic.
|
73
82
|
T-test statistic measure between cells of two groups.
|
83
|
+
- "ks_test": Kolmogorov-Smirnov test statistic.
|
84
|
+
Kolmogorov-Smirnov test statistic measure between cells of two groups.
|
74
85
|
- "nb_ll": log-likelihood over negative binomial
|
75
86
|
Average of log-likelihoods of samples of the secondary group after fitting a negative binomial distribution
|
76
87
|
over the samples of the first group.
|
88
|
+
- "classifier_proba": probability of a binary classifier
|
89
|
+
Average of the classification probability of the perturbation for a binary classifier.
|
90
|
+
- "classifier_cp": classifier class projection
|
91
|
+
Average of the class
|
92
|
+
- "mean_var_distribution": Distance between mean-variance distributions between cells of 2 groups.
|
93
|
+
Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
|
94
|
+
- "mahalanobis": Mahalanobis distance between the means of cells from two groups.
|
95
|
+
It is originally used to measure distance between a point and a distribution.
|
96
|
+
in this context, it quantifies the difference between the mean profiles of a target group and a reference group.
|
77
97
|
|
78
98
|
Attributes:
|
79
99
|
metric: Name of distance metric.
|
@@ -93,6 +113,7 @@ class Distance:
|
|
93
113
|
def __init__(
|
94
114
|
self,
|
95
115
|
metric: str = "edistance",
|
116
|
+
agg_fct: Callable = np.mean,
|
96
117
|
layer_key: str = None,
|
97
118
|
obsm_key: str = None,
|
98
119
|
cell_wise_metric: str = "euclidean",
|
@@ -100,57 +121,67 @@ class Distance:
|
|
100
121
|
"""Initialize Distance class.
|
101
122
|
|
102
123
|
Args:
|
103
|
-
metric: Distance metric to use.
|
124
|
+
metric: Distance metric to use.
|
125
|
+
agg_fct: Aggregation function to generate pseudobulk vectors.
|
104
126
|
layer_key: Name of the counts layer containing raw counts to calculate distances for.
|
105
127
|
Mutually exclusive with 'obsm_key'.
|
106
|
-
|
128
|
+
Is not used if `None`.
|
107
129
|
obsm_key: Name of embedding in adata.obsm to use.
|
108
|
-
Mutually exclusive with '
|
109
|
-
Defaults to None, but is set to "X_pca" if not set
|
130
|
+
Mutually exclusive with 'layer_key'.
|
131
|
+
Defaults to None, but is set to "X_pca" if not explicitly set internally.
|
110
132
|
cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
|
111
|
-
Defaults to "euclidean".
|
112
133
|
"""
|
113
134
|
metric_fct: AbstractDistance = None
|
135
|
+
self.aggregation_func = agg_fct
|
114
136
|
if metric == "edistance":
|
115
137
|
metric_fct = Edistance()
|
116
138
|
elif metric == "euclidean":
|
117
|
-
metric_fct = EuclideanDistance()
|
139
|
+
metric_fct = EuclideanDistance(self.aggregation_func)
|
118
140
|
elif metric == "root_mean_squared_error":
|
119
|
-
metric_fct = EuclideanDistance()
|
141
|
+
metric_fct = EuclideanDistance(self.aggregation_func)
|
120
142
|
elif metric == "mse":
|
121
|
-
metric_fct = MeanSquaredDistance()
|
143
|
+
metric_fct = MeanSquaredDistance(self.aggregation_func)
|
122
144
|
elif metric == "mean_absolute_error":
|
123
|
-
metric_fct = MeanAbsoluteDistance()
|
145
|
+
metric_fct = MeanAbsoluteDistance(self.aggregation_func)
|
124
146
|
elif metric == "pearson_distance":
|
125
|
-
metric_fct = PearsonDistance()
|
147
|
+
metric_fct = PearsonDistance(self.aggregation_func)
|
126
148
|
elif metric == "spearman_distance":
|
127
|
-
metric_fct = SpearmanDistance()
|
149
|
+
metric_fct = SpearmanDistance(self.aggregation_func)
|
128
150
|
elif metric == "kendalltau_distance":
|
129
|
-
metric_fct = KendallTauDistance()
|
151
|
+
metric_fct = KendallTauDistance(self.aggregation_func)
|
130
152
|
elif metric == "cosine_distance":
|
131
|
-
metric_fct = CosineDistance()
|
153
|
+
metric_fct = CosineDistance(self.aggregation_func)
|
132
154
|
elif metric == "r2_distance":
|
133
|
-
metric_fct = R2ScoreDistance()
|
155
|
+
metric_fct = R2ScoreDistance(self.aggregation_func)
|
134
156
|
elif metric == "mean_pairwise":
|
135
157
|
metric_fct = MeanPairwiseDistance()
|
136
158
|
elif metric == "mmd":
|
137
159
|
metric_fct = MMD()
|
138
160
|
elif metric == "wasserstein":
|
139
161
|
metric_fct = WassersteinDistance()
|
140
|
-
elif metric == "
|
141
|
-
metric_fct =
|
162
|
+
elif metric == "sym_kldiv":
|
163
|
+
metric_fct = SymmetricKLDivergence()
|
142
164
|
elif metric == "t_test":
|
143
165
|
metric_fct = TTestDistance()
|
166
|
+
elif metric == "ks_test":
|
167
|
+
metric_fct = KSTestDistance()
|
144
168
|
elif metric == "nb_ll":
|
145
169
|
metric_fct = NBLL()
|
170
|
+
elif metric == "classifier_proba":
|
171
|
+
metric_fct = ClassifierProbaDistance()
|
172
|
+
elif metric == "classifier_cp":
|
173
|
+
metric_fct = ClassifierClassProjection()
|
174
|
+
elif metric == "mean_var_distribution":
|
175
|
+
metric_fct = MeanVarDistributionDistance()
|
176
|
+
elif metric == "mahalanobis":
|
177
|
+
metric_fct = MahalanobisDistance(self.aggregation_func)
|
146
178
|
else:
|
147
179
|
raise ValueError(f"Metric {metric} not recognized.")
|
148
180
|
self.metric_fct = metric_fct
|
149
181
|
|
150
182
|
if layer_key and obsm_key:
|
151
183
|
raise ValueError(
|
152
|
-
"Cannot use '
|
153
|
-
"Please provide only one of the two keys."
|
184
|
+
"Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
|
154
185
|
)
|
155
186
|
if not layer_key and not obsm_key:
|
156
187
|
obsm_key = "X_pca"
|
@@ -183,37 +214,80 @@ class Distance:
|
|
183
214
|
>>> D = Distance(X, Y)
|
184
215
|
"""
|
185
216
|
if issparse(X):
|
186
|
-
X = X.
|
217
|
+
X = X.toarray()
|
187
218
|
if issparse(Y):
|
188
|
-
Y = Y.
|
219
|
+
Y = Y.toarray()
|
189
220
|
|
190
221
|
if len(X) == 0 or len(Y) == 0:
|
191
222
|
raise ValueError("Neither X nor Y can be empty.")
|
192
223
|
|
193
224
|
return self.metric_fct(X, Y, **kwargs)
|
194
225
|
|
226
|
+
def bootstrap(
|
227
|
+
self,
|
228
|
+
X: np.ndarray,
|
229
|
+
Y: np.ndarray,
|
230
|
+
*,
|
231
|
+
n_bootstrap: int = 100,
|
232
|
+
random_state: int = 0,
|
233
|
+
**kwargs,
|
234
|
+
) -> MeanVar:
|
235
|
+
"""Bootstrap computation of mean and variance of the distance between vectors X and Y.
|
236
|
+
|
237
|
+
Args:
|
238
|
+
X: First vector of shape (n_samples, n_features).
|
239
|
+
Y: Second vector of shape (n_samples, n_features).
|
240
|
+
n_bootstrap: Number of bootstrap samples.
|
241
|
+
random_state: Random state for bootstrapping.
|
242
|
+
|
243
|
+
Returns:
|
244
|
+
MeanVar: Mean and variance of distance between X and Y.
|
245
|
+
|
246
|
+
Examples:
|
247
|
+
>>> import pertpy as pt
|
248
|
+
>>> adata = pt.dt.distance_example()
|
249
|
+
>>> Distance = pt.tools.Distance(metric="edistance")
|
250
|
+
>>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
|
251
|
+
>>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
|
252
|
+
>>> D = Distance.bootstrap(X, Y)
|
253
|
+
"""
|
254
|
+
return self._bootstrap_mode(
|
255
|
+
X,
|
256
|
+
Y,
|
257
|
+
n_bootstraps=n_bootstrap,
|
258
|
+
random_state=random_state,
|
259
|
+
**kwargs,
|
260
|
+
)
|
261
|
+
|
195
262
|
def pairwise(
|
196
263
|
self,
|
197
264
|
adata: AnnData,
|
198
265
|
groupby: str,
|
199
266
|
groups: list[str] | None = None,
|
267
|
+
bootstrap: bool = False,
|
268
|
+
n_bootstrap: int = 100,
|
269
|
+
random_state: int = 0,
|
200
270
|
show_progressbar: bool = True,
|
201
271
|
n_jobs: int = -1,
|
202
272
|
**kwargs,
|
203
|
-
) -> pd.DataFrame:
|
273
|
+
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
|
204
274
|
"""Get pairwise distances between groups of cells.
|
205
275
|
|
206
276
|
Args:
|
207
277
|
adata: Annotated data matrix.
|
208
278
|
groupby: Column name in adata.obs.
|
209
279
|
groups: List of groups to compute pairwise distances for.
|
210
|
-
If None, uses all groups.
|
211
|
-
|
280
|
+
If None, uses all groups.
|
281
|
+
bootstrap: Whether to bootstrap the distance.
|
282
|
+
n_bootstrap: Number of bootstrap samples.
|
283
|
+
random_state: Random state for bootstrapping.
|
284
|
+
show_progressbar: Whether to show progress bar.
|
212
285
|
n_jobs: Number of cores to use. Defaults to -1 (all).
|
213
286
|
kwargs: Additional keyword arguments passed to the metric function.
|
214
287
|
|
215
288
|
Returns:
|
216
289
|
pd.DataFrame: Dataframe with pairwise distances.
|
290
|
+
tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
|
217
291
|
|
218
292
|
Examples:
|
219
293
|
>>> import pertpy as pt
|
@@ -224,6 +298,8 @@ class Distance:
|
|
224
298
|
groups = adata.obs[groupby].unique() if groups is None else groups
|
225
299
|
grouping = adata.obs[groupby].copy()
|
226
300
|
df = pd.DataFrame(index=groups, columns=groups, dtype=float)
|
301
|
+
if bootstrap:
|
302
|
+
df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
|
227
303
|
fct = track if show_progressbar else lambda iterable: iterable
|
228
304
|
|
229
305
|
# Some metrics are able to handle precomputed distances. This means that
|
@@ -239,16 +315,29 @@ class Distance:
|
|
239
315
|
for index_x, group_x in enumerate(fct(groups)):
|
240
316
|
idx_x = grouping == group_x
|
241
317
|
for group_y in groups[index_x:]: # type: ignore
|
242
|
-
|
243
|
-
|
318
|
+
# subset the pairwise distance matrix to the two groups
|
319
|
+
idx_y = grouping == group_y
|
320
|
+
sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
|
321
|
+
sub_idx = grouping[idx_x | idx_y] == group_x
|
322
|
+
if not bootstrap:
|
323
|
+
if group_x == group_y:
|
324
|
+
dist = 0.0
|
325
|
+
else:
|
326
|
+
dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
|
327
|
+
df.loc[group_x, group_y] = dist
|
328
|
+
df.loc[group_y, group_x] = dist
|
329
|
+
|
244
330
|
else:
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
331
|
+
bootstrap_output = self._bootstrap_mode_precomputed(
|
332
|
+
sub_pwd,
|
333
|
+
sub_idx,
|
334
|
+
n_bootstraps=n_bootstrap,
|
335
|
+
random_state=random_state,
|
336
|
+
**kwargs,
|
337
|
+
)
|
338
|
+
# In the bootstrap case, distance of group to itself is a mean and can be non-zero
|
339
|
+
df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
|
340
|
+
df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
|
252
341
|
else:
|
253
342
|
if self.layer_key:
|
254
343
|
embedding = adata.layers[self.layer_key]
|
@@ -257,18 +346,39 @@ class Distance:
|
|
257
346
|
for index_x, group_x in enumerate(fct(groups)):
|
258
347
|
cells_x = embedding[grouping == group_x].copy()
|
259
348
|
for group_y in groups[index_x:]: # type: ignore
|
260
|
-
|
261
|
-
|
349
|
+
cells_y = embedding[grouping == group_y].copy()
|
350
|
+
if not bootstrap:
|
351
|
+
# By distance axiom, the distance between a group and itself is 0
|
352
|
+
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
353
|
+
|
354
|
+
df.loc[group_x, group_y] = dist
|
355
|
+
df.loc[group_y, group_x] = dist
|
262
356
|
else:
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
357
|
+
bootstrap_output = self.bootstrap(
|
358
|
+
cells_x,
|
359
|
+
cells_y,
|
360
|
+
n_bootstrap=n_bootstrap,
|
361
|
+
random_state=random_state,
|
362
|
+
**kwargs,
|
363
|
+
)
|
364
|
+
# In the bootstrap case, distance of group to itself is a mean and can be non-zero
|
365
|
+
df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
|
366
|
+
df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
|
367
|
+
|
267
368
|
df.index.name = groupby
|
268
369
|
df.columns.name = groupby
|
269
370
|
df.name = f"pairwise {self.metric}"
|
270
371
|
|
271
|
-
|
372
|
+
if not bootstrap:
|
373
|
+
return df
|
374
|
+
else:
|
375
|
+
df = df.fillna(0)
|
376
|
+
df_var.index.name = groupby
|
377
|
+
df_var.columns.name = groupby
|
378
|
+
df_var = df_var.fillna(0)
|
379
|
+
df_var.name = f"pairwise {self.metric} variance"
|
380
|
+
|
381
|
+
return df, df_var
|
272
382
|
|
273
383
|
def onesided_distances(
|
274
384
|
self,
|
@@ -276,24 +386,32 @@ class Distance:
|
|
276
386
|
groupby: str,
|
277
387
|
selected_group: str | None = None,
|
278
388
|
groups: list[str] | None = None,
|
389
|
+
bootstrap: bool = False,
|
390
|
+
n_bootstrap: int = 100,
|
391
|
+
random_state: int = 0,
|
279
392
|
show_progressbar: bool = True,
|
280
393
|
n_jobs: int = -1,
|
281
394
|
**kwargs,
|
282
|
-
) -> pd.DataFrame:
|
283
|
-
"""Get
|
395
|
+
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
|
396
|
+
"""Get distances between one selected cell group and the remaining other cell groups.
|
284
397
|
|
285
398
|
Args:
|
286
399
|
adata: Annotated data matrix.
|
287
400
|
groupby: Column name in adata.obs.
|
288
401
|
selected_group: Group to compute pairwise distances to all other.
|
289
402
|
groups: List of groups to compute distances to selected_group for.
|
290
|
-
If None, uses all groups.
|
291
|
-
|
403
|
+
If None, uses all groups.
|
404
|
+
bootstrap: Whether to bootstrap the distance.
|
405
|
+
n_bootstrap: Number of bootstrap samples.
|
406
|
+
random_state: Random state for bootstrapping.
|
407
|
+
show_progressbar: Whether to show progress bar.
|
292
408
|
n_jobs: Number of cores to use. Defaults to -1 (all).
|
293
409
|
kwargs: Additional keyword arguments passed to the metric function.
|
294
410
|
|
295
411
|
Returns:
|
296
412
|
pd.DataFrame: Dataframe with distances of groups to selected_group.
|
413
|
+
tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
|
414
|
+
|
297
415
|
|
298
416
|
Examples:
|
299
417
|
>>> import pertpy as pt
|
@@ -301,16 +419,31 @@ class Distance:
|
|
301
419
|
>>> Distance = pt.tools.Distance(metric="edistance")
|
302
420
|
>>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
|
303
421
|
"""
|
422
|
+
if self.metric == "classifier_cp":
|
423
|
+
if bootstrap:
|
424
|
+
raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.")
|
425
|
+
return self.metric_fct.onesided_distances( # type: ignore
|
426
|
+
adata,
|
427
|
+
groupby,
|
428
|
+
selected_group,
|
429
|
+
groups,
|
430
|
+
show_progressbar,
|
431
|
+
n_jobs,
|
432
|
+
**kwargs,
|
433
|
+
)
|
434
|
+
|
304
435
|
groups = adata.obs[groupby].unique() if groups is None else groups
|
305
436
|
grouping = adata.obs[groupby].copy()
|
306
437
|
df = pd.Series(index=groups, dtype=float)
|
438
|
+
if bootstrap:
|
439
|
+
df_var = pd.Series(index=groups, dtype=float)
|
307
440
|
fct = track if show_progressbar else lambda iterable: iterable
|
308
441
|
|
309
442
|
# Some metrics are able to handle precomputed distances. This means that
|
310
443
|
# the pairwise distances between all cells are computed once and then
|
311
444
|
# passed to the metric function. This is much faster than computing the
|
312
445
|
# pairwise distances for each group separately. Other metrics are not
|
313
|
-
# able to handle precomputed distances such as the
|
446
|
+
# able to handle precomputed distances such as the PseudobulkDistance.
|
314
447
|
if self.metric_fct.accepts_precomputed:
|
315
448
|
# Precompute the pairwise distances if needed
|
316
449
|
if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
|
@@ -320,28 +453,59 @@ class Distance:
|
|
320
453
|
idx_x = grouping == group_x
|
321
454
|
group_y = selected_group
|
322
455
|
if group_x == group_y:
|
323
|
-
|
456
|
+
df.loc[group_x] = 0.0 # by distance axiom
|
324
457
|
else:
|
325
458
|
idx_y = grouping == group_y
|
326
459
|
# subset the pairwise distance matrix to the two groups
|
327
460
|
sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
|
328
461
|
sub_idx = grouping[idx_x | idx_y] == group_x
|
329
|
-
|
330
|
-
|
462
|
+
if not bootstrap:
|
463
|
+
dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
|
464
|
+
df.loc[group_x] = dist
|
465
|
+
else:
|
466
|
+
bootstrap_output = self._bootstrap_mode_precomputed(
|
467
|
+
sub_pwd,
|
468
|
+
sub_idx,
|
469
|
+
n_bootstraps=n_bootstrap,
|
470
|
+
random_state=random_state,
|
471
|
+
**kwargs,
|
472
|
+
)
|
473
|
+
df.loc[group_x] = bootstrap_output.mean
|
474
|
+
df_var.loc[group_x] = bootstrap_output.variance
|
331
475
|
else:
|
332
|
-
|
476
|
+
if self.layer_key:
|
477
|
+
embedding = adata.layers[self.layer_key]
|
478
|
+
else:
|
479
|
+
embedding = adata.obsm[self.obsm_key].copy()
|
333
480
|
for group_x in fct(groups):
|
334
481
|
cells_x = embedding[grouping == group_x].copy()
|
335
482
|
group_y = selected_group
|
336
|
-
|
337
|
-
|
483
|
+
cells_y = embedding[grouping == group_y].copy()
|
484
|
+
if not bootstrap:
|
485
|
+
# By distance axiom, the distance between a group and itself is 0
|
486
|
+
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
487
|
+
df.loc[group_x] = dist
|
338
488
|
else:
|
339
|
-
|
340
|
-
|
341
|
-
|
489
|
+
bootstrap_output = self.bootstrap(
|
490
|
+
cells_x,
|
491
|
+
cells_y,
|
492
|
+
n_bootstrap=n_bootstrap,
|
493
|
+
random_state=random_state,
|
494
|
+
**kwargs,
|
495
|
+
)
|
496
|
+
# In the bootstrap case, distance of group to itself is a mean and can be non-zero
|
497
|
+
df.loc[group_x] = bootstrap_output.mean
|
498
|
+
df_var.loc[group_x] = bootstrap_output.variance
|
342
499
|
df.index.name = groupby
|
343
500
|
df.name = f"{self.metric} to {selected_group}"
|
344
|
-
|
501
|
+
if not bootstrap:
|
502
|
+
return df
|
503
|
+
else:
|
504
|
+
df_var.index.name = groupby
|
505
|
+
df_var = df_var.fillna(0)
|
506
|
+
df_var.name = f"pairwise {self.metric} variance to {selected_group}"
|
507
|
+
|
508
|
+
return df, df_var
|
345
509
|
|
346
510
|
def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
|
347
511
|
"""Precompute pairwise distances between all cells, writes to adata.obsp.
|
@@ -367,6 +531,77 @@ class Distance:
|
|
367
531
|
pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
|
368
532
|
adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
|
369
533
|
|
534
|
+
def compare_distance(
|
535
|
+
self,
|
536
|
+
pert: np.ndarray,
|
537
|
+
pred: np.ndarray,
|
538
|
+
ctrl: np.ndarray,
|
539
|
+
mode: Literal["simple", "scaled"] = "simple",
|
540
|
+
fit_to_pert_and_ctrl: bool = False,
|
541
|
+
**kwargs,
|
542
|
+
) -> float:
|
543
|
+
"""Compute the score of simulating a perturbation.
|
544
|
+
|
545
|
+
Args:
|
546
|
+
pert: Real perturbed data.
|
547
|
+
pred: Simulated perturbed data.
|
548
|
+
ctrl: Control data
|
549
|
+
mode: Mode to use.
|
550
|
+
fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
|
551
|
+
kwargs: Additional keyword arguments passed to the metric function.
|
552
|
+
"""
|
553
|
+
if mode == "simple":
|
554
|
+
pass # nothing to be done
|
555
|
+
elif mode == "scaled":
|
556
|
+
from sklearn.preprocessing import MinMaxScaler
|
557
|
+
|
558
|
+
scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl)
|
559
|
+
pred = scaler.transform(pred)
|
560
|
+
pert = scaler.transform(pert)
|
561
|
+
else:
|
562
|
+
raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")
|
563
|
+
|
564
|
+
d1 = self.metric_fct(pert, pred, **kwargs)
|
565
|
+
d2 = self.metric_fct(ctrl, pred, **kwargs)
|
566
|
+
return d1 / d2
|
567
|
+
|
568
|
+
def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
|
569
|
+
rng = np.random.default_rng(random_state)
|
570
|
+
|
571
|
+
distances = []
|
572
|
+
for _ in range(n_bootstraps):
|
573
|
+
X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)]
|
574
|
+
Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)]
|
575
|
+
|
576
|
+
distance = self(X_bootstrapped, Y_bootstrapped, **kwargs)
|
577
|
+
distances.append(distance)
|
578
|
+
|
579
|
+
mean = np.mean(distances)
|
580
|
+
variance = np.var(distances)
|
581
|
+
return MeanVar(mean=mean, variance=variance)
|
582
|
+
|
583
|
+
def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
|
584
|
+
rng = np.random.default_rng(random_state)
|
585
|
+
|
586
|
+
distances = []
|
587
|
+
for _ in range(n_bootstraps):
|
588
|
+
# To maintain the number of cells for both groups (whatever balancing they may have),
|
589
|
+
# we sample the positive and negative indices separately
|
590
|
+
bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True)
|
591
|
+
bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True)
|
592
|
+
bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx])
|
593
|
+
bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx)
|
594
|
+
|
595
|
+
bootstrap_sub_idx = sub_idx[bootstrap_idx]
|
596
|
+
bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs]
|
597
|
+
|
598
|
+
distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs)
|
599
|
+
distances.append(distance)
|
600
|
+
|
601
|
+
mean = np.mean(distances)
|
602
|
+
variance = np.var(distances)
|
603
|
+
return MeanVar(mean=mean, variance=variance)
|
604
|
+
|
370
605
|
|
371
606
|
class AbstractDistance(ABC):
|
372
607
|
"""Abstract class of distance metrics between two sets of vectors."""
|
@@ -471,11 +706,8 @@ class WassersteinDistance(AbstractDistance):
|
|
471
706
|
return self.solve_ot_problem(geom, **kwargs)
|
472
707
|
|
473
708
|
def solve_ot_problem(self, geom: Geometry, **kwargs):
|
474
|
-
# Define a linear problem with that cost structure.
|
475
709
|
ot_prob = LinearProblem(geom)
|
476
|
-
# Create a Sinkhorn solver
|
477
710
|
solver = Sinkhorn()
|
478
|
-
# Solve OT problem
|
479
711
|
ot = solver(ot_prob, **kwargs)
|
480
712
|
return ot.reg_ot_cost.item()
|
481
713
|
|
@@ -483,12 +715,17 @@ class WassersteinDistance(AbstractDistance):
|
|
483
715
|
class EuclideanDistance(AbstractDistance):
|
484
716
|
"""Euclidean distance between pseudobulk vectors."""
|
485
717
|
|
486
|
-
def __init__(self) -> None:
|
718
|
+
def __init__(self, aggregation_func: Callable = np.mean) -> None:
|
487
719
|
super().__init__()
|
488
720
|
self.accepts_precomputed = False
|
721
|
+
self.aggregation_func = aggregation_func
|
489
722
|
|
490
723
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
491
|
-
return np.linalg.norm(
|
724
|
+
return np.linalg.norm(
|
725
|
+
self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
|
726
|
+
ord=2,
|
727
|
+
**kwargs,
|
728
|
+
)
|
492
729
|
|
493
730
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
494
731
|
raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.")
|
@@ -497,12 +734,21 @@ class EuclideanDistance(AbstractDistance):
|
|
497
734
|
class MeanSquaredDistance(AbstractDistance):
|
498
735
|
"""Mean squared distance between pseudobulk vectors."""
|
499
736
|
|
500
|
-
def __init__(self) -> None:
|
737
|
+
def __init__(self, aggregation_func: Callable = np.mean) -> None:
|
501
738
|
super().__init__()
|
502
739
|
self.accepts_precomputed = False
|
740
|
+
self.aggregation_func = aggregation_func
|
503
741
|
|
504
742
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
505
|
-
return
|
743
|
+
return (
|
744
|
+
np.linalg.norm(
|
745
|
+
self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
|
746
|
+
ord=2,
|
747
|
+
**kwargs,
|
748
|
+
)
|
749
|
+
** 2
|
750
|
+
/ X.shape[1]
|
751
|
+
)
|
506
752
|
|
507
753
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
508
754
|
raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.")
|
@@ -511,12 +757,20 @@ class MeanSquaredDistance(AbstractDistance):
|
|
511
757
|
class MeanAbsoluteDistance(AbstractDistance):
|
512
758
|
"""Absolute (Norm-1) distance between pseudobulk vectors."""
|
513
759
|
|
514
|
-
def __init__(self) -> None:
|
760
|
+
def __init__(self, aggregation_func: Callable = np.mean) -> None:
|
515
761
|
super().__init__()
|
516
762
|
self.accepts_precomputed = False
|
763
|
+
self.aggregation_func = aggregation_func
|
517
764
|
|
518
765
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
519
|
-
return
|
766
|
+
return (
|
767
|
+
np.linalg.norm(
|
768
|
+
self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
|
769
|
+
ord=1,
|
770
|
+
**kwargs,
|
771
|
+
)
|
772
|
+
/ X.shape[1]
|
773
|
+
)
|
520
774
|
|
521
775
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
522
776
|
raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.")
|
@@ -541,12 +795,13 @@ class MeanPairwiseDistance(AbstractDistance):
|
|
541
795
|
class PearsonDistance(AbstractDistance):
    """Pearson distance between pseudobulk vectors.

    Defined as ``1 - r``, where ``r`` is the Pearson correlation coefficient of the
    two inputs after reduction along axis 0 with ``aggregation_func``.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.aggregation_func = aggregation_func
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        bulk_x = self.aggregation_func(X, axis=0)
        bulk_y = self.aggregation_func(Y, axis=0)
        correlation = pearsonr(bulk_x, bulk_y)[0]
        return 1 - correlation

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        message = "PearsonDistance cannot be called on a pairwise distance matrix."
        raise NotImplementedError(message)
|
@@ -555,12 +810,13 @@ class PearsonDistance(AbstractDistance):
|
|
555
810
|
class SpearmanDistance(AbstractDistance):
    """Spearman distance between pseudobulk vectors.

    Defined as ``1 - rho``, where ``rho`` is the Spearman rank correlation of the
    two inputs after reduction along axis 0 with ``aggregation_func``.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.aggregation_func = aggregation_func
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        bulk_x = self.aggregation_func(X, axis=0)
        bulk_y = self.aggregation_func(Y, axis=0)
        rank_correlation = spearmanr(bulk_x, bulk_y)[0]
        return 1 - rank_correlation

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        message = "SpearmanDistance cannot be called on a pairwise distance matrix."
        raise NotImplementedError(message)
|
@@ -569,12 +825,13 @@ class SpearmanDistance(AbstractDistance):
|
|
569
825
|
class KendallTauDistance(AbstractDistance):
|
570
826
|
"""Kendall-tau distance between pseudobulk vectors."""
|
571
827
|
|
572
|
-
def __init__(self) -> None:
|
828
|
+
def __init__(self, aggregation_func: Callable = np.mean) -> None:
|
573
829
|
super().__init__()
|
574
830
|
self.accepts_precomputed = False
|
831
|
+
self.aggregation_func = aggregation_func
|
575
832
|
|
576
833
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
577
|
-
x, y =
|
834
|
+
x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)
|
578
835
|
n = len(x)
|
579
836
|
tau_corr = kendalltau(x, y).statistic
|
580
837
|
tau_dist = (1 - tau_corr) * n * (n - 1) / 4
|
@@ -587,12 +844,13 @@ class KendallTauDistance(AbstractDistance):
|
|
587
844
|
class CosineDistance(AbstractDistance):
    """Cosine distance between pseudobulk vectors.

    Both inputs are reduced along axis 0 with ``aggregation_func``. The result is
    ``scipy.spatial.distance.cosine`` of the two vectors.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.aggregation_func = aggregation_func
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        bulk_x = self.aggregation_func(X, axis=0)
        bulk_y = self.aggregation_func(Y, axis=0)
        return cosine(bulk_x, bulk_y)

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        message = "CosineDistance cannot be called on a pairwise distance matrix."
        raise NotImplementedError(message)
|
@@ -603,22 +861,24 @@ class R2ScoreDistance(AbstractDistance):
|
|
603
861
|
|
604
862
|
# NOTE: This is not a distance metric but a similarity metric.
|
605
863
|
|
606
|
-
def __init__(self, aggregation_func: Callable = np.mean) -> None:
    """Configure the metric.

    Args:
        aggregation_func: Reducer called as ``aggregation_func(M, axis=0)`` on each
            input matrix to build its pseudobulk vector. Defaults to ``np.mean``.
    """
    super().__init__()
    self.aggregation_func = aggregation_func
    # Works on raw feature matrices only; a pairwise distance matrix is not accepted.
    self.accepts_precomputed = False
|
609
868
|
|
610
869
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
    """Return ``1 - r2_score(agg(X), agg(Y))`` for the two pseudobulk vectors.

    ``agg(X)`` is passed as ``y_true`` and ``agg(Y)`` as ``y_pred``.
    """
    reference = self.aggregation_func(X, axis=0)
    prediction = self.aggregation_func(Y, axis=0)
    return 1 - r2_score(reference, prediction)
|
612
871
|
|
613
872
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
    """Not supported: this metric needs the raw matrices, not pairwise distances.

    Raises:
        NotImplementedError: Always.
    """
    message = "R2ScoreDistance cannot be called on a pairwise distance matrix."
    raise NotImplementedError(message)
|
615
874
|
|
616
875
|
|
617
|
-
class
|
618
|
-
"""Average of KL divergence between gene distributions of two groups
|
876
|
+
class SymmetricKLDivergence(AbstractDistance):
|
877
|
+
"""Average of symmetric KL divergence between gene distributions of two groups
|
619
878
|
|
620
879
|
Assuming a Gaussian distribution for each gene in each group, calculates
|
621
|
-
the KL divergence between them and averages over all genes
|
880
|
+
the KL divergence between them and averages over all genes. Repeats this ABBA to get a symmetrized distance.
|
881
|
+
See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Symmetrised_divergence.
|
622
882
|
|
623
883
|
"""
|
624
884
|
|
@@ -632,11 +892,12 @@ class KLDivergence(AbstractDistance):
|
|
632
892
|
x_mean, x_std = X[:, i].mean(), X[:, i].std() + epsilon
|
633
893
|
y_mean, y_std = Y[:, i].mean(), Y[:, i].std() + epsilon
|
634
894
|
kl = np.log(y_std / x_std) + (x_std**2 + (x_mean - y_mean) ** 2) / (2 * y_std**2) - 1 / 2
|
635
|
-
|
895
|
+
klr = np.log(x_std / y_std) + (y_std**2 + (y_mean - x_mean) ** 2) / (2 * x_std**2) - 1 / 2
|
896
|
+
kl_all.append(kl + klr)
|
636
897
|
return sum(kl_all) / len(kl_all)
|
637
898
|
|
638
899
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
639
|
-
raise NotImplementedError("
|
900
|
+
raise NotImplementedError("SymmetricKLDivergence cannot be called on a pairwise distance matrix.")
|
640
901
|
|
641
902
|
|
642
903
|
class TTestDistance(AbstractDistance):
|
@@ -663,6 +924,23 @@ class TTestDistance(AbstractDistance):
|
|
663
924
|
raise NotImplementedError("TTestDistance cannot be called on a pairwise distance matrix.")
|
664
925
|
|
665
926
|
|
927
|
+
class KSTestDistance(AbstractDistance):
    """Average of two-sided KS test statistic between two groups.

    For every column, ``scipy.stats.kstest`` compares that column of ``X`` with the
    same column of ``Y``. The absolute statistics are then averaged over all columns.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        per_column = [abs(kstest(X[:, col], Y[:, col])[0]) for col in range(X.shape[1])]
        return sum(per_column) / len(per_column)

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        message = "KSTestDistance cannot be called on a pairwise distance matrix."
        raise NotImplementedError(message)
|
942
|
+
|
943
|
+
|
666
944
|
class NBLL(AbstractDistance):
|
667
945
|
"""
|
668
946
|
Average of Log likelihood (scalar) of group B cells
|
@@ -683,16 +961,12 @@ class NBLL(AbstractDistance):
|
|
683
961
|
if not _is_count_matrix(matrix=X) or not _is_count_matrix(matrix=Y):
|
684
962
|
raise ValueError("NBLL distance only works for raw counts.")
|
685
963
|
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
if mu[0] == np.nan or theta[0] == np.nan:
|
693
|
-
raise ValueError("Could not fit a negative binomial distribution to the input data")
|
694
|
-
# calculate the nll of y
|
695
|
-
eps = np.repeat(epsilon, y.shape[0])
|
964
|
+
@numba.jit(forceobj=True)
|
965
|
+
def _compute_nll(y: np.ndarray, nb_params: tuple[float, float], epsilon: float) -> float:
|
966
|
+
mu = np.exp(nb_params[0])
|
967
|
+
theta = 1 / nb_params[1]
|
968
|
+
eps = epsilon
|
969
|
+
|
696
970
|
log_theta_mu_eps = np.log(theta + mu + eps)
|
697
971
|
nll = (
|
698
972
|
theta * (np.log(theta + eps) - log_theta_mu_eps)
|
@@ -701,9 +975,221 @@ class NBLL(AbstractDistance):
|
|
701
975
|
- gammaln(theta)
|
702
976
|
- gammaln(y + 1)
|
703
977
|
)
|
704
|
-
|
978
|
+
return nll.mean()
|
979
|
+
|
980
|
+
def _process_gene(x: np.ndarray, y: np.ndarray, epsilon: float) -> float:
|
981
|
+
try:
|
982
|
+
nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
|
983
|
+
return _compute_nll(y, nb_params, epsilon)
|
984
|
+
except np.linalg.linalg.LinAlgError:
|
985
|
+
if x.mean() < 10 and y.mean() < 10:
|
986
|
+
return 0.0
|
987
|
+
else:
|
988
|
+
return np.nan # Use NaN to indicate skipped genes
|
989
|
+
|
990
|
+
nlls = []
|
991
|
+
genes_skipped = 0
|
992
|
+
|
993
|
+
for i in range(X.shape[1]):
|
994
|
+
nll = _process_gene(X[:, i], Y[:, i], epsilon)
|
995
|
+
if np.isnan(nll):
|
996
|
+
genes_skipped += 1
|
997
|
+
else:
|
998
|
+
nlls.append(nll)
|
999
|
+
|
1000
|
+
if genes_skipped > X.shape[1] / 2:
|
1001
|
+
raise AttributeError(f"{genes_skipped} genes could not be fit, which is over half.")
|
705
1002
|
|
706
|
-
return -sum(nlls) / len(nlls)
|
1003
|
+
return -np.sum(nlls) / len(nlls)
|
707
1004
|
|
708
1005
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
709
1006
|
raise NotImplementedError("NBLL cannot be called on a pairwise distance matrix.")
|
1007
|
+
|
1008
|
+
|
1009
|
+
def _sample(X, frac=None, n=None, random_state=None):
    """Randomly split the rows of ``X`` into ``(train, test)`` parts.

    Exactly one of ``frac`` or ``n`` determines the size of the test part.

    Args:
        X: Array-like supporting ``X.shape[0]`` and row indexing ``X[idx, :]``.
        frac: Fraction of rows to put in the test part, in ``(0, 1]``. At least one
            row is always sampled.
        n: Number of rows to put in the test part, between 1 and ``X.shape[0]``.
        random_state: Seed or ``np.random.Generator`` handed to
            ``np.random.default_rng``. ``None`` (the default) draws fresh entropy,
            matching the previous non-reproducible behaviour.

    Returns:
        ``(train, test)``: the remaining rows (in their original order) and the
        sampled rows.

    Raises:
        ValueError: If both or neither of ``frac``/``n`` are given, or either is
            out of range.
    """
    # Compare against None explicitly: 0 / 0.0 are real (invalid) values, not "missing".
    if frac is not None and n is not None:
        raise ValueError("Cannot pass both frac and n.")

    n_rows = X.shape[0]
    if frac is not None:
        if not 0 < frac <= 1:
            raise ValueError(f"`frac` must be in (0, 1], got {frac}.")
        n_cells = max(1, int(n_rows * frac))
    elif n is not None:
        if not 1 <= n <= n_rows:
            raise ValueError(f"`n` must be between 1 and {n_rows}, got {n}.")
        n_cells = n
    else:
        raise ValueError("Must pass either `frac` or `n`.")

    rng = np.random.default_rng(random_state)
    sampled_indices = rng.choice(n_rows, n_cells, replace=False)
    # setdiff1d returns sorted indices, so train keeps the original row order.
    remaining_indices = np.setdiff1d(np.arange(n_rows), sampled_indices)
    return X[remaining_indices, :], X[sampled_indices, :]
|
1024
|
+
|
1025
|
+
|
1026
|
+
class ClassifierProbaDistance(AbstractDistance):
    """Average of classification probabilities of a binary classifier.

    Assumes the first condition (``X``) is control and the second (``Y``) is
    perturbed. A random 20% of the perturbed rows is always held out. A logistic
    regression is fitted on all control rows plus the remaining perturbed rows.
    The distance is the mean predicted probability of the perturbed class on the
    held-out rows.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        perturbed_train, perturbed_test = _sample(Y, frac=0.2)
        targets = ["c"] * X.shape[0] + ["p"] * perturbed_train.shape[0]
        features = np.concatenate([X, perturbed_train])

        model = LogisticRegression()
        model.fit(features, targets)
        probabilities = model.predict_proba(perturbed_test)
        # Column 1 is label "p": sklearn orders classes_ sorted, and "c" < "p".
        return np.mean(probabilities[:, 1])

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        message = "ClassifierProbaDistance cannot be called on a pairwise distance matrix."
        raise NotImplementedError(message)
|
1049
|
+
|
1050
|
+
|
1051
|
+
class ClassifierClassProjection(AbstractDistance):
    """Average of 1-(classification probability of control).

    A multi-class logistic regression is trained on every cell outside
    ``selected_group`` (labels from ``adata.obs[groupby]``) and applied to the cells
    of ``selected_group``. For each reported group the value is 1 minus the mean
    predicted probability of that group's class.

    Warning: unlike all other distances, this must also take a list of categorical labels the same length as X.
    Only :meth:`onesided_distances` is supported; calling the instance directly raises.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("ClassifierClassProjection can currently only be called with onesided.")

    def onesided_distances(
        self,
        adata: AnnData,
        groupby: str,
        selected_group: str | None = None,
        groups: list[str] | None = None,
        show_progressbar: bool = True,
        n_jobs: int = -1,
        **kwargs,
    ) -> Series:
        """Unlike the parent function, all groups except the selected group are factored into the classifier.

        Similar to the parent function, the returned dataframe contains only the specified groups.

        Args:
            adata: Annotated data; ``adata.X`` is used as the feature matrix.
            groupby: Column of ``adata.obs`` holding the group labels.
            selected_group: Group whose cells are classified. Required despite the
                ``None`` default (kept for signature compatibility).
            groups: Groups to report. Defaults to all unique values of ``adata.obs[groupby]``.
            show_progressbar: Whether to wrap the loop over ``groups`` with ``track``.
            n_jobs: Accepted for signature compatibility; not used here.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            Series indexed by group (index name ``groupby``). ``selected_group`` maps to 0.

        Raises:
            ValueError: If ``selected_group`` is None or has no cells, or if a requested
                group has no cells outside ``selected_group`` (the classifier then has no
                class for it).
        """
        # Without a selected group the "test" set would be empty and sklearn would fail
        # with an unrelated-looking error; fail early with a clear message instead.
        if selected_group is None:
            raise ValueError("ClassifierClassProjection.onesided_distances requires `selected_group`.")

        groups = adata.obs[groupby].unique() if groups is None else groups
        fct = track if show_progressbar else lambda iterable: iterable

        reference = adata[adata.obs[groupby] != selected_group]
        X = reference.X
        labels = reference.obs[groupby].values
        Y = adata[adata.obs[groupby] == selected_group].X
        if Y.shape[0] == 0:
            raise ValueError(f"No cells with {groupby!r} == {selected_group!r} were found.")

        reg = LogisticRegression()
        reg.fit(X, labels)
        test_probas = reg.predict_proba(Y)
        # predict_proba column for each class label, built once instead of per group.
        class_idx = {label: i for i, label in enumerate(reg.classes_)}

        df = pd.Series(index=groups, dtype=float)

        for group in fct(groups):
            if group == selected_group:
                df.loc[group] = 0
            elif group in class_idx:
                df.loc[group] = 1 - np.mean(test_probas[:, class_idx[group]])
            else:
                raise ValueError(
                    f"Group {group!r} has no cells outside {selected_group!r}; the classifier has no class for it."
                )
        df.index.name = groupby
        df.name = f"classifier_cp to {selected_group}"

        return df

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.")
|
1104
|
+
|
1105
|
+
|
1106
|
+
class MeanVarDistributionDistance(AbstractDistance):
    """Distance between mean-var distributions of gene expression.

    Each matrix is summarised by the log mean and log variance of every column with
    positive mean and positive variance. A kernel density estimate of this 2-D point
    cloud is evaluated for both inputs on a shared grid. The distance is the mean
    squared difference of the two ``score_samples`` outputs over the grid.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        """Difference of mean-var distributions in 2 matrices.

        Args:
            X: Normalized and log transformed cells x genes count matrix.
            Y: Normalized and log transformed cells x genes count matrix.
            **kwargs: Ignored.

        Returns:
            Mean over grid points of the squared difference between the KDE scores of X and Y.
        """

        def _mean_var(x, log: bool = False):
            mean = np.mean(x, axis=0)
            var = np.var(x, axis=0)
            # Genes with zero mean are dropped. When log-transforming, zero-variance genes
            # must go too: log(0) = -inf would propagate into the KDE input and make the
            # grid bounds (min/max) infinite.
            keep = mean > 0
            if log:
                keep &= var > 0
            mean = mean[keep]
            var = var[keep]
            if log:
                mean = np.log(mean)
                var = np.log(var)
            return mean, var

        def _prep_kde_data(x, y):
            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)

        def _grid_points(d, n_points=100):
            # Make grid, add 1 bin on lower/upper end to get final n_points
            d_min = d.min()
            d_max = d.max()
            # Compute bin size
            d_bin = (d_max - d_min) / (n_points - 2)
            d_min = d_min - d_bin
            d_max = d_max + d_bin
            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)

        def _parallel_score_samples(kde, samples, thread_count=max(1, int(0.875 * multiprocessing.cpu_count()))):
            # the thread_count is determined using the factor 0.875 as recommended here:
            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
            # It is clamped to >= 1: on a single-CPU machine int(0.875) == 0 and Pool(0) raises.
            with multiprocessing.Pool(thread_count) as p:
                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))

        def _kde_eval(d, grid):
            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
            # can not be compared well on regions further away from the data as they are -inf
            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
            return _parallel_score_samples(kde, grid)

        mean_x, var_x = _mean_var(X, log=True)
        mean_y, var_y = _mean_var(Y, log=True)

        x = _prep_kde_data(mean_x, var_x)
        y = _prep_kde_data(mean_y, var_y)

        # Gridpoints to eval KDE on
        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
        var_grid = _grid_points(np.concatenate([var_x, var_y]))
        grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)

        kde_x = _kde_eval(x, grid)
        kde_y = _kde_eval(y, grid)

        kde_diff = ((kde_x - kde_y) ** 2).mean()

        return kde_diff

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
|
1177
|
+
|
1178
|
+
|
1179
|
+
class MahalanobisDistance(AbstractDistance):
    """Mahalanobis distance between pseudobulk vectors.

    Both inputs are reduced along axis 0 with ``aggregation_func``. The covariance
    is estimated from ``X`` alone, with columns as variables.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.accepts_precomputed = False
        self.aggregation_func = aggregation_func

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        # np.cov(X.T) treats each column of X as a variable. atleast_2d keeps the
        # single-column case a (1, 1) matrix instead of a 0-d scalar.
        cov = np.atleast_2d(np.cov(X.T))
        # With more columns than rows (common for cells x genes) or collinear columns the
        # covariance is singular and np.linalg.inv raises LinAlgError. The pseudo-inverse
        # matches the inverse (up to rounding) whenever one exists and stays defined otherwise.
        inv_cov = np.linalg.pinv(cov)
        return mahalanobis(
            self.aggregation_func(X, axis=0),
            self.aggregation_func(Y, axis=0),
            inv_cov,
        )

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.")
|