pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_distances/_distances.py

@@ -1,35 +1,46 @@
 from __future__ import annotations
 
+import multiprocessing
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal, NamedTuple
 
+import numba
 import numpy as np
 import pandas as pd
 from ott.geometry.geometry import Geometry
 from ott.geometry.pointcloud import PointCloud
 from ott.problems.linear.linear_problem import LinearProblem
 from ott.solvers.linear.sinkhorn import Sinkhorn
+from pandas import Series
 from rich.progress import track
 from scipy.sparse import issparse
-from scipy.spatial.distance import cosine
+from scipy.spatial.distance import cosine, mahalanobis
 from scipy.special import gammaln
-from scipy.stats import kendalltau, pearsonr, spearmanr
+from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
+from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import pairwise_distances, r2_score
 from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
+from sklearn.neighbors import KernelDensity
 from statsmodels.discrete.discrete_model import NegativeBinomialP
 
 if TYPE_CHECKING:
-    from collections.abc import
+    from collections.abc import Callable
 
     from anndata import AnnData
 
 
+class MeanVar(NamedTuple):
+    mean: float
+    variance: float
+
+
 class Distance:
     """Distance class, used to compute distances between groups of cells.
 
     The distance metric can be specified by the user. This class also provides a
     method to compute the pairwise distances between all groups of cells.
     Currently available metrics:
+
     - "edistance": Energy distance (Default metric).
         In essence, it is twice the mean pairwise distance between cells of two
        groups minus the mean pairwise distance between cells within each group
@@ -55,8 +66,6 @@ class Distance:
         Coefficient of determination distance between the means of cells from two groups.
     - "mean_pairwise": Mean pairwise distance.
         Mean of the pairwise euclidean distances between cells of two groups.
-    - "mean_pairwise": Mean pairwise distance.
-        Mean of the pairwise euclidean distances between cells of two groups.
     - "mmd": Maximum mean discrepancy
         Maximum mean discrepancy between the cells of two groups.
         Here, uses linear, rbf, and quadratic polynomial MMD. For theory on MMD in single-cell applications, see
@@ -66,14 +75,25 @@ class Distance:
         OTT-JAX implementation of the Sinkhorn algorithm to compute the distance.
         For more information on the optimal transport solver, see
         `Cuturi et al. (2013) <https://proceedings.neurips.cc/paper/2013/file/af21d0c97db2e27e13572cbf59eb343d-Paper.pdf>`__.
-    - "kl_divergence": Kullback–Leibler divergence distance.
+    - "sym_kldiv": symmetrized Kullback–Leibler divergence distance.
         Kullback–Leibler divergence of the gaussian distributions between cells of two groups.
-        Here we fit a gaussian distribution over
+        Here we fit a gaussian distribution over one group of cells and then calculate the KL divergence on the other, and vice versa.
     - "t_test": t-test statistic.
         T-test statistic measure between cells of two groups.
+    - "ks_test": Kolmogorov-Smirnov test statistic.
+        Kolmogorov-Smirnov test statistic measure between cells of two groups.
     - "nb_ll": log-likelihood over negative binomial
         Average of log-likelihoods of samples of the secondary group after fitting a negative binomial distribution
         over the samples of the first group.
+    - "classifier_proba": probability of a binary classifier
+        Average of the classification probability of the perturbation for a binary classifier.
+    - "classifier_cp": classifier class projection
+        Average of the class
+    - "mean_var_distribution": Distance between mean-variance distributions between cells of 2 groups.
+        Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
+    - "mahalanobis": Mahalanobis distance between the means of cells from two groups.
+        It is originally used to measure distance between a point and a distribution.
+        In this context, it quantifies the difference between the mean profiles of a target group and a reference group.
 
     Attributes:
         metric: Name of distance metric.
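To make the expanded metric list concrete, here is a minimal usage sketch of selecting metrics by name; it reuses the `distance_example` dataset and `X_pca` embedding that the docstring examples in this diff rely on.

```python
import pertpy as pt

adata = pt.dt.distance_example()

# Each metric name dispatches to one of the AbstractDistance subclasses below
for metric in ["edistance", "sym_kldiv", "ks_test", "mahalanobis"]:
    distance = pt.tools.Distance(metric=metric, obsm_key="X_pca")
    X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
    Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
    print(metric, distance(X, Y))
```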
@@ -93,6 +113,7 @@ class Distance:
     def __init__(
         self,
         metric: str = "edistance",
+        agg_fct: Callable = np.mean,
         layer_key: str = None,
         obsm_key: str = None,
         cell_wise_metric: str = "euclidean",
@@ -100,57 +121,67 @@ class Distance:
         """Initialize Distance class.
 
         Args:
-            metric: Distance metric to use.
+            metric: Distance metric to use.
+            agg_fct: Aggregation function to generate pseudobulk vectors.
             layer_key: Name of the counts layer containing raw counts to calculate distances for.
                 Mutually exclusive with 'obsm_key'.
-
+                Is not used if `None`.
             obsm_key: Name of embedding in adata.obsm to use.
-                Mutually exclusive with '
-                Defaults to None, but is set to "X_pca" if not set
+                Mutually exclusive with 'layer_key'.
+                Defaults to None, but is set to "X_pca" if not explicitly set internally.
             cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
-                Defaults to "euclidean".
         """
         metric_fct: AbstractDistance = None
+        self.aggregation_func = agg_fct
         if metric == "edistance":
             metric_fct = Edistance()
         elif metric == "euclidean":
-            metric_fct = EuclideanDistance()
+            metric_fct = EuclideanDistance(self.aggregation_func)
         elif metric == "root_mean_squared_error":
-            metric_fct = EuclideanDistance()
+            metric_fct = EuclideanDistance(self.aggregation_func)
         elif metric == "mse":
-            metric_fct = MeanSquaredDistance()
+            metric_fct = MeanSquaredDistance(self.aggregation_func)
         elif metric == "mean_absolute_error":
-            metric_fct = MeanAbsoluteDistance()
+            metric_fct = MeanAbsoluteDistance(self.aggregation_func)
         elif metric == "pearson_distance":
-            metric_fct = PearsonDistance()
+            metric_fct = PearsonDistance(self.aggregation_func)
         elif metric == "spearman_distance":
-            metric_fct = SpearmanDistance()
+            metric_fct = SpearmanDistance(self.aggregation_func)
         elif metric == "kendalltau_distance":
-            metric_fct = KendallTauDistance()
+            metric_fct = KendallTauDistance(self.aggregation_func)
         elif metric == "cosine_distance":
-            metric_fct = CosineDistance()
+            metric_fct = CosineDistance(self.aggregation_func)
         elif metric == "r2_distance":
-            metric_fct = R2ScoreDistance()
+            metric_fct = R2ScoreDistance(self.aggregation_func)
         elif metric == "mean_pairwise":
             metric_fct = MeanPairwiseDistance()
         elif metric == "mmd":
             metric_fct = MMD()
         elif metric == "wasserstein":
             metric_fct = WassersteinDistance()
-        elif metric == "kl_divergence":
-            metric_fct = KLDivergence()
+        elif metric == "sym_kldiv":
+            metric_fct = SymmetricKLDivergence()
         elif metric == "t_test":
             metric_fct = TTestDistance()
+        elif metric == "ks_test":
+            metric_fct = KSTestDistance()
         elif metric == "nb_ll":
             metric_fct = NBLL()
+        elif metric == "classifier_proba":
+            metric_fct = ClassifierProbaDistance()
+        elif metric == "classifier_cp":
+            metric_fct = ClassifierClassProjection()
+        elif metric == "mean_var_distribution":
+            metric_fct = MeanVarDistributionDistance()
+        elif metric == "mahalanobis":
+            metric_fct = MahalanobisDistance(self.aggregation_func)
         else:
             raise ValueError(f"Metric {metric} not recognized.")
         self.metric_fct = metric_fct
 
         if layer_key and obsm_key:
             raise ValueError(
-                "Cannot use 'layer_key' and 'obsm_key' at the same time.\n"
-                "Please provide only one of the two keys."
+                "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
             )
         if not layer_key and not obsm_key:
             obsm_key = "X_pca"
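The new `agg_fct` argument controls how each group is collapsed into a pseudobulk vector before the pseudobulk metrics above compare them; a minimal sketch with synthetic data:

```python
import numpy as np
import pertpy as pt

rng = np.random.default_rng(0)
X = rng.normal(loc=0.0, size=(100, 10))
Y = rng.normal(loc=1.0, size=(100, 10))

# Default: compare group means; alternative: compare group medians
d_mean = pt.tools.Distance(metric="euclidean")
d_median = pt.tools.Distance(metric="euclidean", agg_fct=np.median)
print(d_mean(X, Y), d_median(X, Y))
```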
@@ -183,37 +214,80 @@ class Distance:
         >>> D = Distance(X, Y)
         """
         if issparse(X):
-            X = X.A
+            X = X.toarray()
         if issparse(Y):
-            Y = Y.A
+            Y = Y.toarray()
 
         if len(X) == 0 or len(Y) == 0:
             raise ValueError("Neither X nor Y can be empty.")
 
         return self.metric_fct(X, Y, **kwargs)
 
+    def bootstrap(
+        self,
+        X: np.ndarray,
+        Y: np.ndarray,
+        *,
+        n_bootstrap: int = 100,
+        random_state: int = 0,
+        **kwargs,
+    ) -> MeanVar:
+        """Bootstrap computation of mean and variance of the distance between vectors X and Y.
+
+        Args:
+            X: First vector of shape (n_samples, n_features).
+            Y: Second vector of shape (n_samples, n_features).
+            n_bootstrap: Number of bootstrap samples.
+            random_state: Random state for bootstrapping.
+
+        Returns:
+            MeanVar: Mean and variance of distance between X and Y.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.distance_example()
+            >>> Distance = pt.tools.Distance(metric="edistance")
+            >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
+            >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
+            >>> D = Distance.bootstrap(X, Y)
+        """
+        return self._bootstrap_mode(
+            X,
+            Y,
+            n_bootstraps=n_bootstrap,
+            random_state=random_state,
+            **kwargs,
+        )
+
     def pairwise(
         self,
         adata: AnnData,
         groupby: str,
         groups: list[str] | None = None,
+        bootstrap: bool = False,
+        n_bootstrap: int = 100,
+        random_state: int = 0,
         show_progressbar: bool = True,
         n_jobs: int = -1,
         **kwargs,
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
         """Get pairwise distances between groups of cells.
 
         Args:
             adata: Annotated data matrix.
             groupby: Column name in adata.obs.
             groups: List of groups to compute pairwise distances for.
-                If None, uses all groups.
-
+                If None, uses all groups.
+            bootstrap: Whether to bootstrap the distance.
+            n_bootstrap: Number of bootstrap samples.
+            random_state: Random state for bootstrapping.
+            show_progressbar: Whether to show progress bar.
             n_jobs: Number of cores to use. Defaults to -1 (all).
             kwargs: Additional keyword arguments passed to the metric function.
 
         Returns:
             pd.DataFrame: Dataframe with pairwise distances.
+            tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
 
         Examples:
             >>> import pertpy as pt
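`bootstrap` returns the `MeanVar` named tuple introduced at the top of the file, so results support both attribute access and tuple unpacking; this sketch mirrors the docstring example:

```python
import pertpy as pt

adata = pt.dt.distance_example()
distance = pt.tools.Distance(metric="edistance")

X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]

result = distance.bootstrap(X, Y, n_bootstrap=100, random_state=0)
mean, variance = result              # NamedTuple unpacking
print(result.mean, result.variance)  # or attribute access
```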
@@ -224,6 +298,8 @@ class Distance:
         groups = adata.obs[groupby].unique() if groups is None else groups
         grouping = adata.obs[groupby].copy()
         df = pd.DataFrame(index=groups, columns=groups, dtype=float)
+        if bootstrap:
+            df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
         fct = track if show_progressbar else lambda iterable: iterable
 
         # Some metrics are able to handle precomputed distances. This means that
@@ -239,16 +315,29 @@ class Distance:
             for index_x, group_x in enumerate(fct(groups)):
                 idx_x = grouping == group_x
                 for group_y in groups[index_x:]:  # type: ignore
-
-
+                    # subset the pairwise distance matrix to the two groups
+                    idx_y = grouping == group_y
+                    sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
+                    sub_idx = grouping[idx_x | idx_y] == group_x
+                    if not bootstrap:
+                        if group_x == group_y:
+                            dist = 0.0
+                        else:
+                            dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
+                        df.loc[group_x, group_y] = dist
+                        df.loc[group_y, group_x] = dist
+
                     else:
-
-
-
-
-
-
-
+                        bootstrap_output = self._bootstrap_mode_precomputed(
+                            sub_pwd,
+                            sub_idx,
+                            n_bootstraps=n_bootstrap,
+                            random_state=random_state,
+                            **kwargs,
+                        )
+                        # In the bootstrap case, distance of group to itself is a mean and can be non-zero
+                        df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
+                        df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
         else:
             if self.layer_key:
                 embedding = adata.layers[self.layer_key]
@@ -257,18 +346,39 @@ class Distance:
             for index_x, group_x in enumerate(fct(groups)):
                 cells_x = embedding[grouping == group_x].copy()
                 for group_y in groups[index_x:]:  # type: ignore
-
-
+                    cells_y = embedding[grouping == group_y].copy()
+                    if not bootstrap:
+                        # By distance axiom, the distance between a group and itself is 0
+                        dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
+
+                        df.loc[group_x, group_y] = dist
+                        df.loc[group_y, group_x] = dist
                    else:
-
-
-
-
+                        bootstrap_output = self.bootstrap(
+                            cells_x,
+                            cells_y,
+                            n_bootstrap=n_bootstrap,
+                            random_state=random_state,
+                            **kwargs,
+                        )
+                        # In the bootstrap case, distance of group to itself is a mean and can be non-zero
+                        df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
+                        df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
+
         df.index.name = groupby
         df.columns.name = groupby
         df.name = f"pairwise {self.metric}"
 
-        return df
+        if not bootstrap:
+            return df
+        else:
+            df = df.fillna(0)
+            df_var.index.name = groupby
+            df_var.columns.name = groupby
+            df_var = df_var.fillna(0)
+            df_var.name = f"pairwise {self.metric} variance"
+
+            return df, df_var
 
     def onesided_distances(
         self,
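With `bootstrap=True`, `pairwise` switches from a single symmetric DataFrame to a `(mean, variance)` pair of DataFrames, and the diagonal of the mean matrix may be non-zero because it is itself a bootstrap mean; a minimal sketch:

```python
import pertpy as pt

adata = pt.dt.distance_example()
distance = pt.tools.Distance(metric="edistance")

# Point estimates: one symmetric DataFrame with a zero diagonal
df = distance.pairwise(adata, groupby="perturbation")

# Bootstrapped: mean and variance DataFrames
df_mean, df_var = distance.pairwise(
    adata, groupby="perturbation", bootstrap=True, n_bootstrap=50, random_state=0
)
```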
@@ -276,24 +386,32 @@ class Distance:
         groupby: str,
         selected_group: str | None = None,
         groups: list[str] | None = None,
+        bootstrap: bool = False,
+        n_bootstrap: int = 100,
+        random_state: int = 0,
         show_progressbar: bool = True,
         n_jobs: int = -1,
         **kwargs,
-    ) -> pd.DataFrame:
-        """Get
+    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
+        """Get distances between one selected cell group and the remaining other cell groups.
 
         Args:
             adata: Annotated data matrix.
             groupby: Column name in adata.obs.
             selected_group: Group to compute pairwise distances to all other.
             groups: List of groups to compute distances to selected_group for.
-                If None, uses all groups.
-
+                If None, uses all groups.
+            bootstrap: Whether to bootstrap the distance.
+            n_bootstrap: Number of bootstrap samples.
+            random_state: Random state for bootstrapping.
+            show_progressbar: Whether to show progress bar.
             n_jobs: Number of cores to use. Defaults to -1 (all).
             kwargs: Additional keyword arguments passed to the metric function.
 
         Returns:
             pd.DataFrame: Dataframe with distances of groups to selected_group.
+            tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
+
 
         Examples:
             >>> import pertpy as pt
@@ -301,16 +419,31 @@ class Distance:
         >>> Distance = pt.tools.Distance(metric="edistance")
         >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
         """
+        if self.metric == "classifier_cp":
+            if bootstrap:
+                raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.")
+            return self.metric_fct.onesided_distances(  # type: ignore
+                adata,
+                groupby,
+                selected_group,
+                groups,
+                show_progressbar,
+                n_jobs,
+                **kwargs,
+            )
+
         groups = adata.obs[groupby].unique() if groups is None else groups
         grouping = adata.obs[groupby].copy()
         df = pd.Series(index=groups, dtype=float)
+        if bootstrap:
+            df_var = pd.Series(index=groups, dtype=float)
         fct = track if show_progressbar else lambda iterable: iterable
 
         # Some metrics are able to handle precomputed distances. This means that
         # the pairwise distances between all cells are computed once and then
         # passed to the metric function. This is much faster than computing the
         # pairwise distances for each group separately. Other metrics are not
-        # able to handle precomputed distances such as the
+        # able to handle precomputed distances such as the PseudobulkDistance.
         if self.metric_fct.accepts_precomputed:
             # Precompute the pairwise distances if needed
             if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
@@ -320,28 +453,59 @@ class Distance:
                 idx_x = grouping == group_x
                 group_y = selected_group
                 if group_x == group_y:
-
+                    df.loc[group_x] = 0.0  # by distance axiom
                 else:
                     idx_y = grouping == group_y
                     # subset the pairwise distance matrix to the two groups
                     sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
                     sub_idx = grouping[idx_x | idx_y] == group_x
-
-
+                    if not bootstrap:
+                        dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
+                        df.loc[group_x] = dist
+                    else:
+                        bootstrap_output = self._bootstrap_mode_precomputed(
+                            sub_pwd,
+                            sub_idx,
+                            n_bootstraps=n_bootstrap,
+                            random_state=random_state,
+                            **kwargs,
+                        )
+                        df.loc[group_x] = bootstrap_output.mean
+                        df_var.loc[group_x] = bootstrap_output.variance
         else:
-
+            if self.layer_key:
+                embedding = adata.layers[self.layer_key]
+            else:
+                embedding = adata.obsm[self.obsm_key].copy()
             for group_x in fct(groups):
                 cells_x = embedding[grouping == group_x].copy()
                 group_y = selected_group
-
-
+                cells_y = embedding[grouping == group_y].copy()
+                if not bootstrap:
+                    # By distance axiom, the distance between a group and itself is 0
+                    dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
+                    df.loc[group_x] = dist
                 else:
-
-
-
+                    bootstrap_output = self.bootstrap(
+                        cells_x,
+                        cells_y,
+                        n_bootstrap=n_bootstrap,
+                        random_state=random_state,
+                        **kwargs,
+                    )
+                    # In the bootstrap case, distance of group to itself is a mean and can be non-zero
+                    df.loc[group_x] = bootstrap_output.mean
+                    df_var.loc[group_x] = bootstrap_output.variance
         df.index.name = groupby
         df.name = f"{self.metric} to {selected_group}"
-        return df
+        if not bootstrap:
+            return df
+        else:
+            df_var.index.name = groupby
+            df_var = df_var.fillna(0)
+            df_var.name = f"pairwise {self.metric} variance to {selected_group}"
+
+            return df, df_var
 
     def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
         """Precompute pairwise distances between all cells, writes to adata.obsp.
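The one-sided variant gains the same bootstrap switch, plus a special case: for `classifier_cp` it delegates to `ClassifierClassProjection.onesided_distances` (defined further down) and refuses `bootstrap=True`; a minimal sketch:

```python
import pertpy as pt

adata = pt.dt.distance_example()

# Distance of every group to the control group
distance = pt.tools.Distance(metric="edistance")
df = distance.onesided_distances(adata, groupby="perturbation", selected_group="control")

# classifier_cp is routed to the classifier's own onesided implementation;
# requesting bootstrap=True here raises NotImplementedError
ccp = pt.tools.Distance(metric="classifier_cp")
df_ccp = ccp.onesided_distances(adata, groupby="perturbation", selected_group="control")
```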
@@ -367,6 +531,77 @@ class Distance:
         pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
         adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
 
+    def compare_distance(
+        self,
+        pert: np.ndarray,
+        pred: np.ndarray,
+        ctrl: np.ndarray,
+        mode: Literal["simple", "scaled"] = "simple",
+        fit_to_pert_and_ctrl: bool = False,
+        **kwargs,
+    ) -> float:
+        """Compute the score of simulating a perturbation.
+
+        Args:
+            pert: Real perturbed data.
+            pred: Simulated perturbed data.
+            ctrl: Control data.
+            mode: Mode to use.
+            fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
+            kwargs: Additional keyword arguments passed to the metric function.
+        """
+        if mode == "simple":
+            pass  # nothing to be done
+        elif mode == "scaled":
+            from sklearn.preprocessing import MinMaxScaler
+
+            scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl)
+            pred = scaler.transform(pred)
+            pert = scaler.transform(pert)
+        else:
+            raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")
+
+        d1 = self.metric_fct(pert, pred, **kwargs)
+        d2 = self.metric_fct(ctrl, pred, **kwargs)
+        return d1 / d2
+
+    def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
+        rng = np.random.default_rng(random_state)
+
+        distances = []
+        for _ in range(n_bootstraps):
+            X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)]
+            Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)]
+
+            distance = self(X_bootstrapped, Y_bootstrapped, **kwargs)
+            distances.append(distance)
+
+        mean = np.mean(distances)
+        variance = np.var(distances)
+        return MeanVar(mean=mean, variance=variance)
+
+    def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
+        rng = np.random.default_rng(random_state)
+
+        distances = []
+        for _ in range(n_bootstraps):
+            # To maintain the number of cells for both groups (whatever balancing they may have),
+            # we sample the positive and negative indices separately
+            bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True)
+            bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True)
+            bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx])
+            bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx)
+
+            bootstrap_sub_idx = sub_idx[bootstrap_idx]
+            bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs]
+
+            distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs)
+            distances.append(distance)
+
+        mean = np.mean(distances)
+        variance = np.var(distances)
+        return MeanVar(mean=mean, variance=variance)
+
 
 class AbstractDistance(ABC):
     """Abstract class of distance metrics between two sets of vectors."""
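`compare_distance` scores a simulated perturbation as the ratio `d(pert, pred) / d(ctrl, pred)`, so values below 1 indicate the prediction lands closer to the real perturbed cells than to control; a minimal synthetic sketch:

```python
import numpy as np
import pertpy as pt

rng = np.random.default_rng(0)
ctrl = rng.normal(loc=0.0, size=(200, 20))  # control cells
pert = rng.normal(loc=2.0, size=(200, 20))  # real perturbed cells
pred = rng.normal(loc=1.8, size=(200, 20))  # simulated perturbation, near pert

distance = pt.tools.Distance(metric="euclidean")
score = distance.compare_distance(pert, pred, ctrl, mode="simple")
print(score)  # well below 1: pred is much closer to pert than to ctrl
```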
@@ -471,11 +706,8 @@ class WassersteinDistance(AbstractDistance):
         return self.solve_ot_problem(geom, **kwargs)
 
     def solve_ot_problem(self, geom: Geometry, **kwargs):
-        # Define a linear problem with that cost structure.
         ot_prob = LinearProblem(geom)
-        # Create a Sinkhorn solver
         solver = Sinkhorn()
-        # Solve OT problem
         ot = solver(ot_prob, **kwargs)
         return ot.reg_ot_cost.item()
@@ -483,12 +715,17 @@ class WassersteinDistance(AbstractDistance):
 class EuclideanDistance(AbstractDistance):
     """Euclidean distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs)
+        return np.linalg.norm(
+            self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
+            ord=2,
+            **kwargs,
+        )
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.")
@@ -497,12 +734,21 @@ class EuclideanDistance(AbstractDistance):
 class MeanSquaredDistance(AbstractDistance):
     """Mean squared distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs) ** 2 / X.shape[1]
+        return (
+            np.linalg.norm(
+                self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
+                ord=2,
+                **kwargs,
+            )
+            ** 2
+            / X.shape[1]
+        )
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.")
@@ -511,12 +757,20 @@ class MeanSquaredDistance(AbstractDistance):
 class MeanAbsoluteDistance(AbstractDistance):
     """Absolute (Norm-1) distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=1, **kwargs) / X.shape[1]
+        return (
+            np.linalg.norm(
+                self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
+                ord=1,
+                **kwargs,
+            )
+            / X.shape[1]
+        )
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.")
@@ -541,12 +795,13 @@ class MeanPairwiseDistance(AbstractDistance):
 class PearsonDistance(AbstractDistance):
     """Pearson distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return 1 - pearsonr(X.mean(axis=0), Y.mean(axis=0))[0]
+        return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.")
@@ -555,12 +810,13 @@ class PearsonDistance(AbstractDistance):
 class SpearmanDistance(AbstractDistance):
     """Spearman distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return 1 - spearmanr(X.mean(axis=0), Y.mean(axis=0))[0]
+        return 1 - spearmanr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.")
@@ -569,12 +825,13 @@ class SpearmanDistance(AbstractDistance):
 class KendallTauDistance(AbstractDistance):
     """Kendall-tau distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        x, y = X.mean(axis=0), Y.mean(axis=0)
+        x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)
         n = len(x)
         tau_corr = kendalltau(x, y).statistic
         tau_dist = (1 - tau_corr) * n * (n - 1) / 4
@@ -587,12 +844,13 @@ class KendallTauDistance(AbstractDistance):
 class CosineDistance(AbstractDistance):
     """Cosine distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return cosine(X.mean(axis=0), Y.mean(axis=0))
+        return cosine(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.")
@@ -603,22 +861,24 @@ class R2ScoreDistance(AbstractDistance):
 
     # NOTE: This is not a distance metric but a similarity metric.
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return 1 - r2_score(X.mean(axis=0), Y.mean(axis=0))
+        return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.")
 
 
-class KLDivergence(AbstractDistance):
-    """Average of KL divergence between gene distributions of two groups
+class SymmetricKLDivergence(AbstractDistance):
+    """Average of symmetric KL divergence between gene distributions of two groups
 
     Assuming a Gaussian distribution for each gene in each group, calculates
-    the KL divergence between them and averages over all genes
+    the KL divergence between them and averages over all genes. Repeats this ABBA to get a symmetrized distance.
+    See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Symmetrised_divergence.
 
     """
@@ -632,11 +892,12 @@ class KLDivergence(AbstractDistance):
             x_mean, x_std = X[:, i].mean(), X[:, i].std() + epsilon
             y_mean, y_std = Y[:, i].mean(), Y[:, i].std() + epsilon
             kl = np.log(y_std / x_std) + (x_std**2 + (x_mean - y_mean) ** 2) / (2 * y_std**2) - 1 / 2
-            kl_all.append(kl)
+            klr = np.log(x_std / y_std) + (y_std**2 + (y_mean - x_mean) ** 2) / (2 * x_std**2) - 1 / 2
+            kl_all.append(kl + klr)
         return sum(kl_all) / len(kl_all)
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
-        raise NotImplementedError("KLDivergence cannot be called on a pairwise distance matrix.")
+        raise NotImplementedError("SymmetricKLDivergence cannot be called on a pairwise distance matrix.")
 
 
 class TTestDistance(AbstractDistance):
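For reference, the per-gene quantity accumulated in the loop above is the closed-form KL divergence between two univariate Gaussians, evaluated in both directions and summed; this standard identity (not spelled out in the diff itself) is:

```latex
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu_x,\sigma_x^2)\,\middle\|\,\mathcal{N}(\mu_y,\sigma_y^2)\right)
  = \log\frac{\sigma_y}{\sigma_x}
  + \frac{\sigma_x^2 + (\mu_x - \mu_y)^2}{2\sigma_y^2}
  - \frac{1}{2},
\qquad
D_{\mathrm{sym}} = D_{\mathrm{KL}}(P \| Q) + D_{\mathrm{KL}}(Q \| P)
```

The reported distance is the average of this symmetrized divergence over all genes.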
@@ -663,6 +924,23 @@ class TTestDistance(AbstractDistance):
         raise NotImplementedError("TTestDistance cannot be called on a pairwise distance matrix.")
 
 
+class KSTestDistance(AbstractDistance):
+    """Average of two-sided KS test statistic between two groups"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        stats = []
+        for i in range(X.shape[1]):
+            stats.append(abs(kstest(X[:, i], Y[:, i])[0]))
+        return sum(stats) / len(stats)
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("KSTestDistance cannot be called on a pairwise distance matrix.")
+
+
 class NBLL(AbstractDistance):
     """
     Average of Log likelihood (scalar) of group B cells
@@ -683,16 +961,12 @@ class NBLL(AbstractDistance):
         if not _is_count_matrix(matrix=X) or not _is_count_matrix(matrix=Y):
             raise ValueError("NBLL distance only works for raw counts.")
 
-
-
-
-
-
-
-        if mu[0] == np.nan or theta[0] == np.nan:
-            raise ValueError("Could not fit a negative binomial distribution to the input data")
-        # calculate the nll of y
-        eps = np.repeat(epsilon, y.shape[0])
+        @numba.jit(forceobj=True)
+        def _compute_nll(y: np.ndarray, nb_params: tuple[float, float], epsilon: float) -> float:
+            mu = np.exp(nb_params[0])
+            theta = 1 / nb_params[1]
+            eps = epsilon
+
             log_theta_mu_eps = np.log(theta + mu + eps)
             nll = (
                 theta * (np.log(theta + eps) - log_theta_mu_eps)
@@ -701,9 +975,221 @@ class NBLL(AbstractDistance):
                 - gammaln(theta)
                 - gammaln(y + 1)
             )
-
+            return nll.mean()
+
+        def _process_gene(x: np.ndarray, y: np.ndarray, epsilon: float) -> float:
+            try:
+                nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
+                return _compute_nll(y, nb_params, epsilon)
+            except np.linalg.linalg.LinAlgError:
+                if x.mean() < 10 and y.mean() < 10:
+                    return 0.0
+                else:
+                    return np.nan  # Use NaN to indicate skipped genes
+
+        nlls = []
+        genes_skipped = 0
+
+        for i in range(X.shape[1]):
+            nll = _process_gene(X[:, i], Y[:, i], epsilon)
+            if np.isnan(nll):
+                genes_skipped += 1
+            else:
+                nlls.append(nll)
+
+        if genes_skipped > X.shape[1] / 2:
+            raise AttributeError(f"{genes_skipped} genes could not be fit, which is over half.")
 
-        return -sum(nlls) / len(nlls)
+        return -np.sum(nlls) / len(nlls)
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("NBLL cannot be called on a pairwise distance matrix.")
+
+
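The visible `gammaln` terms belong to the standard negative-binomial log-likelihood with mean mu and inverse dispersion theta; the middle terms are cut off in this diff view, so the textbook form is given here for orientation:

```latex
\log \mathrm{NB}(y;\mu,\theta)
  = \theta \log\frac{\theta}{\theta+\mu}
  + y \log\frac{\mu}{\theta+\mu}
  + \log\Gamma(y+\theta)
  - \log\Gamma(\theta)
  - \log\Gamma(y+1)
```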
+def _sample(X, frac=None, n=None):
+    """Returns subsample of cells in format (train, test)."""
+    if frac and n:
+        raise ValueError("Cannot pass both frac and n.")
+    if frac:
+        n_cells = max(1, int(X.shape[0] * frac))
+    elif n:
+        n_cells = n
+    else:
+        raise ValueError("Must pass either `frac` or `n`.")
+
+    rng = np.random.default_rng()
+    sampled_indices = rng.choice(X.shape[0], n_cells, replace=False)
+    remaining_indices = np.setdiff1d(np.arange(X.shape[0]), sampled_indices)
+    return X[remaining_indices, :], X[sampled_indices, :]
+
+
+class ClassifierProbaDistance(AbstractDistance):
+    """Average of classification probabilities of a binary classifier.
+
+    Assumes the first condition is control and the second is perturbed.
+    Always holds out 20% of the perturbed condition.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        Y_train, Y_test = _sample(Y, frac=0.2)
+        label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]
+        train = np.concatenate([X, Y_train])
+
+        reg = LogisticRegression()
+        reg.fit(train, label)
+        test_labels = reg.predict_proba(Y_test)
+        return np.mean(test_labels[:, 1])
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("ClassifierProbaDistance cannot be called on a pairwise distance matrix.")
+
+
+class ClassifierClassProjection(AbstractDistance):
+    """Average of 1-(classification probability of control).
+
+    Warning: unlike all other distances, this must also take a list of categorical labels the same length as X.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("ClassifierClassProjection can currently only be called with onesided.")
+
+    def onesided_distances(
+        self,
+        adata: AnnData,
+        groupby: str,
+        selected_group: str | None = None,
+        groups: list[str] | None = None,
+        show_progressbar: bool = True,
+        n_jobs: int = -1,
+        **kwargs,
+    ) -> Series:
+        """Unlike the parent function, all groups except the selected group are factored into the classifier.
+
+        Similar to the parent function, the returned dataframe contains only the specified groups.
+        """
+        groups = adata.obs[groupby].unique() if groups is None else groups
+        fct = track if show_progressbar else lambda iterable: iterable
+
+        X = adata[adata.obs[groupby] != selected_group].X
+        labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values
+        Y = adata[adata.obs[groupby] == selected_group].X
+
+        reg = LogisticRegression()
+        reg.fit(X, labels)
+        test_probas = reg.predict_proba(Y)
+
+        df = pd.Series(index=groups, dtype=float)
+
+        for group in fct(groups):
+            if group == selected_group:
+                df.loc[group] = 0
+            else:
+                class_idx = list(reg.classes_).index(group)
+                df.loc[group] = 1 - np.mean(test_probas[:, class_idx])
+        df.index.name = groupby
+        df.name = f"classifier_cp to {selected_group}"
+
+        return df
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.")
+
+
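The two classifier metrics share the `_sample` helper above, which holds out 20% of the perturbed cells as a test set for `classifier_proba`; a minimal usage sketch on the example data:

```python
import pertpy as pt

adata = pt.dt.distance_example()

# How confidently does logistic regression separate perturbed from control?
# X is treated as control, Y as perturbed; 20% of Y is held out for scoring.
distance = pt.tools.Distance(metric="classifier_proba")
X = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
print(distance(X, Y))
```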
+class MeanVarDistributionDistance(AbstractDistance):
+    """Distance between mean-var distributions of gene expression."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        """Difference of mean-var distributions in 2 matrices.
+
+        Args:
+            X: Normalized and log transformed cells x genes count matrix.
+            Y: Normalized and log transformed cells x genes count matrix.
+        """
+
+        def _mean_var(x, log: bool = False):
+            mean = np.mean(x, axis=0)
+            var = np.var(x, axis=0)
+            positive = mean > 0
+            mean = mean[positive]
+            var = var[positive]
+            if log:
+                mean = np.log(mean)
+                var = np.log(var)
+            return mean, var
+
+        def _prep_kde_data(x, y):
+            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
+
+        def _grid_points(d, n_points=100):
+            # Make grid, add 1 bin on lower/upper end to get final n_points
+            d_min = d.min()
+            d_max = d.max()
+            # Compute bin size
+            d_bin = (d_max - d_min) / (n_points - 2)
+            d_min = d_min - d_bin
+            d_max = d_max + d_bin
+            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
+
+        def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
+            # the thread_count is determined using the factor 0.875 as recommended here:
+            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
+            with multiprocessing.Pool(thread_count) as p:
+                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
+
+        def _kde_eval(d, grid):
+            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
+            # can not be compared well on regions further away from the data as they are -inf
+            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
+            return _parallel_score_samples(kde, grid)
+
+        mean_x, var_x = _mean_var(X, log=True)
+        mean_y, var_y = _mean_var(Y, log=True)
+
+        x = _prep_kde_data(mean_x, var_x)
+        y = _prep_kde_data(mean_y, var_y)
+
+        # Gridpoints to eval KDE on
+        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
+        var_grid = _grid_points(np.concatenate([var_x, var_y]))
+        grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
+
+        kde_x = _kde_eval(x, grid)
+        kde_y = _kde_eval(y, grid)
+
+        kde_diff = ((kde_x - kde_y) ** 2).mean()
+
+        return kde_diff
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
+
+
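Since this metric builds kernel density estimates over per-gene (mean, variance) pairs, it expects normalized, log-transformed expression rather than an embedding; a minimal sketch, where the layer name "lognorm" is only a hypothetical placeholder for whatever log-normalized layer the data carries:

```python
import pertpy as pt

adata = pt.dt.distance_example()

# Hypothetical layer name: point layer_key at a normalized, log1p-transformed layer
distance = pt.tools.Distance(metric="mean_var_distribution", layer_key="lognorm")
df = distance.pairwise(adata, groupby="perturbation", show_progressbar=False)
```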
+class MahalanobisDistance(AbstractDistance):
+    """Mahalanobis distance between pseudobulk vectors."""
+
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        return mahalanobis(
+            self.aggregation_func(X, axis=0),
+            self.aggregation_func(Y, axis=0),
+            np.linalg.inv(np.cov(X.T)),
+        )
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.")
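As implemented, the Mahalanobis form uses the covariance of the first group only, so the metric is not symmetric in X and Y, and X must have more cells than features for the covariance matrix to be invertible:

```latex
d(X, Y) = \sqrt{(\bar{x} - \bar{y})^{\top}\, \Sigma_X^{-1}\, (\bar{x} - \bar{y})},
\qquad \bar{x} = \mathrm{agg}(X),\ \ \bar{y} = \mathrm{agg}(Y),\ \ \Sigma_X = \operatorname{cov}(X)
```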