pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,35 +1,46 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import multiprocessing
3
4
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING
5
+ from typing import TYPE_CHECKING, Literal, NamedTuple
5
6
 
7
+ import numba
6
8
  import numpy as np
7
9
  import pandas as pd
8
10
  from ott.geometry.geometry import Geometry
9
11
  from ott.geometry.pointcloud import PointCloud
10
12
  from ott.problems.linear.linear_problem import LinearProblem
11
13
  from ott.solvers.linear.sinkhorn import Sinkhorn
14
+ from pandas import Series
12
15
  from rich.progress import track
13
16
  from scipy.sparse import issparse
14
- from scipy.spatial.distance import cosine
17
+ from scipy.spatial.distance import cosine, mahalanobis
15
18
  from scipy.special import gammaln
16
- from scipy.stats import kendalltau, pearsonr, spearmanr
19
+ from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
20
+ from sklearn.linear_model import LogisticRegression
17
21
  from sklearn.metrics import pairwise_distances, r2_score
18
22
  from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
23
+ from sklearn.neighbors import KernelDensity
19
24
  from statsmodels.discrete.discrete_model import NegativeBinomialP
20
25
 
21
26
  if TYPE_CHECKING:
22
- from collections.abc import Iterable
27
+ from collections.abc import Callable
23
28
 
24
29
  from anndata import AnnData
25
30
 
26
31
 
32
+ class MeanVar(NamedTuple):
33
+ mean: float
34
+ variance: float
35
+
36
+
27
37
  class Distance:
28
38
  """Distance class, used to compute distances between groups of cells.
29
39
 
30
40
  The distance metric can be specified by the user. This class also provides a
31
41
  method to compute the pairwise distances between all groups of cells.
32
42
  Currently available metrics:
43
+
33
44
  - "edistance": Energy distance (Default metric).
34
45
  In essence, it is twice the mean pairwise distance between cells of two
35
46
  groups minus the mean pairwise distance between cells within each group
@@ -55,8 +66,6 @@ class Distance:
55
66
  Coefficient of determination distance between the means of cells from two groups.
56
67
  - "mean_pairwise": Mean pairwise distance.
57
68
  Mean of the pairwise euclidean distances between cells of two groups.
58
- - "mean_pairwise": Mean pairwise distance.
59
- Mean of the pairwise euclidean distances between cells of two groups.
60
69
  - "mmd": Maximum mean discrepancy
61
70
  Maximum mean discrepancy between the cells of two groups.
62
71
  Here, uses linear, rbf, and quadratic polynomial MMD. For theory on MMD in single-cell applications, see
@@ -66,14 +75,25 @@ class Distance:
66
75
  OTT-JAX implementation of the Sinkhorn algorithm to compute the distance.
67
76
  For more information on the optimal transport solver, see
68
77
  `Cuturi et al. (2013) <https://proceedings.neurips.cc/paper/2013/file/af21d0c97db2e27e13572cbf59eb343d-Paper.pdf>`__.
69
- - "kl_divergence": Kullback–Leibler divergence distance.
78
+ - "sym_kldiv": symmetrized Kullback–Leibler divergence distance.
70
79
  Kullback–Leibler divergence of the gaussian distributions between cells of two groups.
71
- Here we fit a gaussian distribution over each group of cells and then calculate the KL divergence
80
+ Here we fit a gaussian distribution over one group of cells and then calculate the KL divergence on the other, and vice versa.
72
81
  - "t_test": t-test statistic.
73
82
  T-test statistic measure between cells of two groups.
83
+ - "ks_test": Kolmogorov-Smirnov test statistic.
84
+ Kolmogorov-Smirnov test statistic measure between cells of two groups.
74
85
  - "nb_ll": log-likelihood over negative binomial
75
86
  Average of log-likelihoods of samples of the secondary group after fitting a negative binomial distribution
76
87
  over the samples of the first group.
88
+ - "classifier_proba": probability of a binary classifier
89
+ Average of the classification probability of the perturbation for a binary classifier.
90
+ - "classifier_cp": classifier class projection
91
+ Average of the class
92
+ - "mean_var_distribution": Distance between mean-variance distributions between cells of 2 groups.
93
+ Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
94
+ - "mahalanobis": Mahalanobis distance between the means of cells from two groups.
95
+ It is originally used to measure distance between a point and a distribution.
96
+ in this context, it quantifies the difference between the mean profiles of a target group and a reference group.
77
97
 
78
98
  Attributes:
79
99
  metric: Name of distance metric.
@@ -93,6 +113,7 @@ class Distance:
93
113
  def __init__(
94
114
  self,
95
115
  metric: str = "edistance",
116
+ agg_fct: Callable = np.mean,
96
117
  layer_key: str = None,
97
118
  obsm_key: str = None,
98
119
  cell_wise_metric: str = "euclidean",
@@ -100,57 +121,67 @@ class Distance:
100
121
  """Initialize Distance class.
101
122
 
102
123
  Args:
103
- metric: Distance metric to use. Defaults to "edistance".
124
+ metric: Distance metric to use.
125
+ agg_fct: Aggregation function to generate pseudobulk vectors.
104
126
  layer_key: Name of the counts layer containing raw counts to calculate distances for.
105
127
  Mutually exclusive with 'obsm_key'.
106
- Defaults to None and is then not used.
128
+ Is not used if `None`.
107
129
  obsm_key: Name of embedding in adata.obsm to use.
108
- Mutually exclusive with 'counts_layer_key'.
109
- Defaults to None, but is set to "X_pca" if not set explicitly internally.
130
+ Mutually exclusive with 'layer_key'.
131
+ Defaults to None, but is set to "X_pca" if not explicitly set internally.
110
132
  cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
111
- Defaults to "euclidean".
112
133
  """
113
134
  metric_fct: AbstractDistance = None
135
+ self.aggregation_func = agg_fct
114
136
  if metric == "edistance":
115
137
  metric_fct = Edistance()
116
138
  elif metric == "euclidean":
117
- metric_fct = EuclideanDistance()
139
+ metric_fct = EuclideanDistance(self.aggregation_func)
118
140
  elif metric == "root_mean_squared_error":
119
- metric_fct = EuclideanDistance()
141
+ metric_fct = EuclideanDistance(self.aggregation_func)
120
142
  elif metric == "mse":
121
- metric_fct = MeanSquaredDistance()
143
+ metric_fct = MeanSquaredDistance(self.aggregation_func)
122
144
  elif metric == "mean_absolute_error":
123
- metric_fct = MeanAbsoluteDistance()
145
+ metric_fct = MeanAbsoluteDistance(self.aggregation_func)
124
146
  elif metric == "pearson_distance":
125
- metric_fct = PearsonDistance()
147
+ metric_fct = PearsonDistance(self.aggregation_func)
126
148
  elif metric == "spearman_distance":
127
- metric_fct = SpearmanDistance()
149
+ metric_fct = SpearmanDistance(self.aggregation_func)
128
150
  elif metric == "kendalltau_distance":
129
- metric_fct = KendallTauDistance()
151
+ metric_fct = KendallTauDistance(self.aggregation_func)
130
152
  elif metric == "cosine_distance":
131
- metric_fct = CosineDistance()
153
+ metric_fct = CosineDistance(self.aggregation_func)
132
154
  elif metric == "r2_distance":
133
- metric_fct = R2ScoreDistance()
155
+ metric_fct = R2ScoreDistance(self.aggregation_func)
134
156
  elif metric == "mean_pairwise":
135
157
  metric_fct = MeanPairwiseDistance()
136
158
  elif metric == "mmd":
137
159
  metric_fct = MMD()
138
160
  elif metric == "wasserstein":
139
161
  metric_fct = WassersteinDistance()
140
- elif metric == "kl_divergence":
141
- metric_fct = KLDivergence()
162
+ elif metric == "sym_kldiv":
163
+ metric_fct = SymmetricKLDivergence()
142
164
  elif metric == "t_test":
143
165
  metric_fct = TTestDistance()
166
+ elif metric == "ks_test":
167
+ metric_fct = KSTestDistance()
144
168
  elif metric == "nb_ll":
145
169
  metric_fct = NBLL()
170
+ elif metric == "classifier_proba":
171
+ metric_fct = ClassifierProbaDistance()
172
+ elif metric == "classifier_cp":
173
+ metric_fct = ClassifierClassProjection()
174
+ elif metric == "mean_var_distribution":
175
+ metric_fct = MeanVarDistributionDistance()
176
+ elif metric == "mahalanobis":
177
+ metric_fct = MahalanobisDistance(self.aggregation_func)
146
178
  else:
147
179
  raise ValueError(f"Metric {metric} not recognized.")
148
180
  self.metric_fct = metric_fct
149
181
 
150
182
  if layer_key and obsm_key:
151
183
  raise ValueError(
152
- "Cannot use 'counts_layer_key' and 'obsm_key' at the same time.\n"
153
- "Please provide only one of the two keys."
184
+ "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
154
185
  )
155
186
  if not layer_key and not obsm_key:
156
187
  obsm_key = "X_pca"
@@ -183,37 +214,80 @@ class Distance:
183
214
  >>> D = Distance(X, Y)
184
215
  """
185
216
  if issparse(X):
186
- X = X.A
217
+ X = X.toarray()
187
218
  if issparse(Y):
188
- Y = Y.A
219
+ Y = Y.toarray()
189
220
 
190
221
  if len(X) == 0 or len(Y) == 0:
191
222
  raise ValueError("Neither X nor Y can be empty.")
192
223
 
193
224
  return self.metric_fct(X, Y, **kwargs)
194
225
 
226
+ def bootstrap(
227
+ self,
228
+ X: np.ndarray,
229
+ Y: np.ndarray,
230
+ *,
231
+ n_bootstrap: int = 100,
232
+ random_state: int = 0,
233
+ **kwargs,
234
+ ) -> MeanVar:
235
+ """Bootstrap computation of mean and variance of the distance between vectors X and Y.
236
+
237
+ Args:
238
+ X: First vector of shape (n_samples, n_features).
239
+ Y: Second vector of shape (n_samples, n_features).
240
+ n_bootstrap: Number of bootstrap samples.
241
+ random_state: Random state for bootstrapping.
242
+
243
+ Returns:
244
+ MeanVar: Mean and variance of distance between X and Y.
245
+
246
+ Examples:
247
+ >>> import pertpy as pt
248
+ >>> adata = pt.dt.distance_example()
249
+ >>> Distance = pt.tools.Distance(metric="edistance")
250
+ >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
251
+ >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
252
+ >>> D = Distance.bootstrap(X, Y)
253
+ """
254
+ return self._bootstrap_mode(
255
+ X,
256
+ Y,
257
+ n_bootstraps=n_bootstrap,
258
+ random_state=random_state,
259
+ **kwargs,
260
+ )
261
+
195
262
  def pairwise(
196
263
  self,
197
264
  adata: AnnData,
198
265
  groupby: str,
199
266
  groups: list[str] | None = None,
267
+ bootstrap: bool = False,
268
+ n_bootstrap: int = 100,
269
+ random_state: int = 0,
200
270
  show_progressbar: bool = True,
201
271
  n_jobs: int = -1,
202
272
  **kwargs,
203
- ) -> pd.DataFrame:
273
+ ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
204
274
  """Get pairwise distances between groups of cells.
205
275
 
206
276
  Args:
207
277
  adata: Annotated data matrix.
208
278
  groupby: Column name in adata.obs.
209
279
  groups: List of groups to compute pairwise distances for.
210
- If None, uses all groups. Defaults to None.
211
- show_progressbar: Whether to show progress bar. Defaults to True.
280
+ If None, uses all groups.
281
+ bootstrap: Whether to bootstrap the distance.
282
+ n_bootstrap: Number of bootstrap samples.
283
+ random_state: Random state for bootstrapping.
284
+ show_progressbar: Whether to show progress bar.
212
285
  n_jobs: Number of cores to use. Defaults to -1 (all).
213
286
  kwargs: Additional keyword arguments passed to the metric function.
214
287
 
215
288
  Returns:
216
289
  pd.DataFrame: Dataframe with pairwise distances.
290
+ tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
217
291
 
218
292
  Examples:
219
293
  >>> import pertpy as pt
@@ -224,6 +298,8 @@ class Distance:
224
298
  groups = adata.obs[groupby].unique() if groups is None else groups
225
299
  grouping = adata.obs[groupby].copy()
226
300
  df = pd.DataFrame(index=groups, columns=groups, dtype=float)
301
+ if bootstrap:
302
+ df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
227
303
  fct = track if show_progressbar else lambda iterable: iterable
228
304
 
229
305
  # Some metrics are able to handle precomputed distances. This means that
@@ -239,16 +315,29 @@ class Distance:
239
315
  for index_x, group_x in enumerate(fct(groups)):
240
316
  idx_x = grouping == group_x
241
317
  for group_y in groups[index_x:]: # type: ignore
242
- if group_x == group_y:
243
- dist = 0.0 # by distance axiom
318
+ # subset the pairwise distance matrix to the two groups
319
+ idx_y = grouping == group_y
320
+ sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
321
+ sub_idx = grouping[idx_x | idx_y] == group_x
322
+ if not bootstrap:
323
+ if group_x == group_y:
324
+ dist = 0.0
325
+ else:
326
+ dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
327
+ df.loc[group_x, group_y] = dist
328
+ df.loc[group_y, group_x] = dist
329
+
244
330
  else:
245
- idx_y = grouping == group_y
246
- # subset the pairwise distance matrix to the two groups
247
- sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
248
- sub_idx = grouping[idx_x | idx_y] == group_x
249
- dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
250
- df.loc[group_x, group_y] = dist
251
- df.loc[group_y, group_x] = dist
331
+ bootstrap_output = self._bootstrap_mode_precomputed(
332
+ sub_pwd,
333
+ sub_idx,
334
+ n_bootstraps=n_bootstrap,
335
+ random_state=random_state,
336
+ **kwargs,
337
+ )
338
+ # In the bootstrap case, distance of group to itself is a mean and can be non-zero
339
+ df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
340
+ df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
252
341
  else:
253
342
  if self.layer_key:
254
343
  embedding = adata.layers[self.layer_key]
@@ -257,18 +346,39 @@ class Distance:
257
346
  for index_x, group_x in enumerate(fct(groups)):
258
347
  cells_x = embedding[grouping == group_x].copy()
259
348
  for group_y in groups[index_x:]: # type: ignore
260
- if group_x == group_y:
261
- dist = 0.0
349
+ cells_y = embedding[grouping == group_y].copy()
350
+ if not bootstrap:
351
+ # By distance axiom, the distance between a group and itself is 0
352
+ dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
353
+
354
+ df.loc[group_x, group_y] = dist
355
+ df.loc[group_y, group_x] = dist
262
356
  else:
263
- cells_y = embedding[grouping == group_y].copy()
264
- dist = self(cells_x, cells_y, **kwargs)
265
- df.loc[group_x, group_y] = dist
266
- df.loc[group_y, group_x] = dist
357
+ bootstrap_output = self.bootstrap(
358
+ cells_x,
359
+ cells_y,
360
+ n_bootstrap=n_bootstrap,
361
+ random_state=random_state,
362
+ **kwargs,
363
+ )
364
+ # In the bootstrap case, distance of group to itself is a mean and can be non-zero
365
+ df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
366
+ df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
367
+
267
368
  df.index.name = groupby
268
369
  df.columns.name = groupby
269
370
  df.name = f"pairwise {self.metric}"
270
371
 
271
- return df
372
+ if not bootstrap:
373
+ return df
374
+ else:
375
+ df = df.fillna(0)
376
+ df_var.index.name = groupby
377
+ df_var.columns.name = groupby
378
+ df_var = df_var.fillna(0)
379
+ df_var.name = f"pairwise {self.metric} variance"
380
+
381
+ return df, df_var
272
382
 
273
383
  def onesided_distances(
274
384
  self,
@@ -276,24 +386,32 @@ class Distance:
276
386
  groupby: str,
277
387
  selected_group: str | None = None,
278
388
  groups: list[str] | None = None,
389
+ bootstrap: bool = False,
390
+ n_bootstrap: int = 100,
391
+ random_state: int = 0,
279
392
  show_progressbar: bool = True,
280
393
  n_jobs: int = -1,
281
394
  **kwargs,
282
- ) -> pd.DataFrame:
283
- """Get pairwise distances between groups of cells.
395
+ ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
396
+ """Get distances between one selected cell group and the remaining other cell groups.
284
397
 
285
398
  Args:
286
399
  adata: Annotated data matrix.
287
400
  groupby: Column name in adata.obs.
288
401
  selected_group: Group to compute pairwise distances to all other.
289
402
  groups: List of groups to compute distances to selected_group for.
290
- If None, uses all groups. Defaults to None.
291
- show_progressbar: Whether to show progress bar. Defaults to True.
403
+ If None, uses all groups.
404
+ bootstrap: Whether to bootstrap the distance.
405
+ n_bootstrap: Number of bootstrap samples.
406
+ random_state: Random state for bootstrapping.
407
+ show_progressbar: Whether to show progress bar.
292
408
  n_jobs: Number of cores to use. Defaults to -1 (all).
293
409
  kwargs: Additional keyword arguments passed to the metric function.
294
410
 
295
411
  Returns:
296
412
  pd.DataFrame: Dataframe with distances of groups to selected_group.
413
+ tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
414
+
297
415
 
298
416
  Examples:
299
417
  >>> import pertpy as pt
@@ -301,16 +419,31 @@ class Distance:
301
419
  >>> Distance = pt.tools.Distance(metric="edistance")
302
420
  >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
303
421
  """
422
+ if self.metric == "classifier_cp":
423
+ if bootstrap:
424
+ raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.")
425
+ return self.metric_fct.onesided_distances( # type: ignore
426
+ adata,
427
+ groupby,
428
+ selected_group,
429
+ groups,
430
+ show_progressbar,
431
+ n_jobs,
432
+ **kwargs,
433
+ )
434
+
304
435
  groups = adata.obs[groupby].unique() if groups is None else groups
305
436
  grouping = adata.obs[groupby].copy()
306
437
  df = pd.Series(index=groups, dtype=float)
438
+ if bootstrap:
439
+ df_var = pd.Series(index=groups, dtype=float)
307
440
  fct = track if show_progressbar else lambda iterable: iterable
308
441
 
309
442
  # Some metrics are able to handle precomputed distances. This means that
310
443
  # the pairwise distances between all cells are computed once and then
311
444
  # passed to the metric function. This is much faster than computing the
312
445
  # pairwise distances for each group separately. Other metrics are not
313
- # able to handle precomputed distances such as the PsuedobulkDistance.
446
+ # able to handle precomputed distances such as the PseudobulkDistance.
314
447
  if self.metric_fct.accepts_precomputed:
315
448
  # Precompute the pairwise distances if needed
316
449
  if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
@@ -320,28 +453,59 @@ class Distance:
320
453
  idx_x = grouping == group_x
321
454
  group_y = selected_group
322
455
  if group_x == group_y:
323
- dist = 0.0 # by distance axiom
456
+ df.loc[group_x] = 0.0 # by distance axiom
324
457
  else:
325
458
  idx_y = grouping == group_y
326
459
  # subset the pairwise distance matrix to the two groups
327
460
  sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
328
461
  sub_idx = grouping[idx_x | idx_y] == group_x
329
- dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
330
- df.loc[group_x] = dist
462
+ if not bootstrap:
463
+ dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
464
+ df.loc[group_x] = dist
465
+ else:
466
+ bootstrap_output = self._bootstrap_mode_precomputed(
467
+ sub_pwd,
468
+ sub_idx,
469
+ n_bootstraps=n_bootstrap,
470
+ random_state=random_state,
471
+ **kwargs,
472
+ )
473
+ df.loc[group_x] = bootstrap_output.mean
474
+ df_var.loc[group_x] = bootstrap_output.variance
331
475
  else:
332
- embedding = adata.obsm[self.obsm_key].copy()
476
+ if self.layer_key:
477
+ embedding = adata.layers[self.layer_key]
478
+ else:
479
+ embedding = adata.obsm[self.obsm_key].copy()
333
480
  for group_x in fct(groups):
334
481
  cells_x = embedding[grouping == group_x].copy()
335
482
  group_y = selected_group
336
- if group_x == group_y:
337
- dist = 0.0
483
+ cells_y = embedding[grouping == group_y].copy()
484
+ if not bootstrap:
485
+ # By distance axiom, the distance between a group and itself is 0
486
+ dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
487
+ df.loc[group_x] = dist
338
488
  else:
339
- cells_y = embedding[grouping == group_y].copy()
340
- dist = self.metric_fct(cells_x, cells_y, **kwargs)
341
- df.loc[group_x] = dist
489
+ bootstrap_output = self.bootstrap(
490
+ cells_x,
491
+ cells_y,
492
+ n_bootstrap=n_bootstrap,
493
+ random_state=random_state,
494
+ **kwargs,
495
+ )
496
+ # In the bootstrap case, distance of group to itself is a mean and can be non-zero
497
+ df.loc[group_x] = bootstrap_output.mean
498
+ df_var.loc[group_x] = bootstrap_output.variance
342
499
  df.index.name = groupby
343
500
  df.name = f"{self.metric} to {selected_group}"
344
- return df
501
+ if not bootstrap:
502
+ return df
503
+ else:
504
+ df_var.index.name = groupby
505
+ df_var = df_var.fillna(0)
506
+ df_var.name = f"pairwise {self.metric} variance to {selected_group}"
507
+
508
+ return df, df_var
345
509
 
346
510
  def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
347
511
  """Precompute pairwise distances between all cells, writes to adata.obsp.
@@ -367,6 +531,77 @@ class Distance:
367
531
  pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
368
532
  adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
369
533
 
534
+ def compare_distance(
535
+ self,
536
+ pert: np.ndarray,
537
+ pred: np.ndarray,
538
+ ctrl: np.ndarray,
539
+ mode: Literal["simple", "scaled"] = "simple",
540
+ fit_to_pert_and_ctrl: bool = False,
541
+ **kwargs,
542
+ ) -> float:
543
+ """Compute the score of simulating a perturbation.
544
+
545
+ Args:
546
+ pert: Real perturbed data.
547
+ pred: Simulated perturbed data.
548
+ ctrl: Control data
549
+ mode: Mode to use.
550
+ fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
551
+ kwargs: Additional keyword arguments passed to the metric function.
552
+ """
553
+ if mode == "simple":
554
+ pass # nothing to be done
555
+ elif mode == "scaled":
556
+ from sklearn.preprocessing import MinMaxScaler
557
+
558
+ scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl)
559
+ pred = scaler.transform(pred)
560
+ pert = scaler.transform(pert)
561
+ else:
562
+ raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")
563
+
564
+ d1 = self.metric_fct(pert, pred, **kwargs)
565
+ d2 = self.metric_fct(ctrl, pred, **kwargs)
566
+ return d1 / d2
567
+
568
+ def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
569
+ rng = np.random.default_rng(random_state)
570
+
571
+ distances = []
572
+ for _ in range(n_bootstraps):
573
+ X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)]
574
+ Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)]
575
+
576
+ distance = self(X_bootstrapped, Y_bootstrapped, **kwargs)
577
+ distances.append(distance)
578
+
579
+ mean = np.mean(distances)
580
+ variance = np.var(distances)
581
+ return MeanVar(mean=mean, variance=variance)
582
+
583
+ def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
584
+ rng = np.random.default_rng(random_state)
585
+
586
+ distances = []
587
+ for _ in range(n_bootstraps):
588
+ # To maintain the number of cells for both groups (whatever balancing they may have),
589
+ # we sample the positive and negative indices separately
590
+ bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True)
591
+ bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True)
592
+ bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx])
593
+ bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx)
594
+
595
+ bootstrap_sub_idx = sub_idx[bootstrap_idx]
596
+ bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs]
597
+
598
+ distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs)
599
+ distances.append(distance)
600
+
601
+ mean = np.mean(distances)
602
+ variance = np.var(distances)
603
+ return MeanVar(mean=mean, variance=variance)
604
+
370
605
 
371
606
  class AbstractDistance(ABC):
372
607
  """Abstract class of distance metrics between two sets of vectors."""
@@ -471,11 +706,8 @@ class WassersteinDistance(AbstractDistance):
471
706
  return self.solve_ot_problem(geom, **kwargs)
472
707
 
473
708
  def solve_ot_problem(self, geom: Geometry, **kwargs):
474
- # Define a linear problem with that cost structure.
475
709
  ot_prob = LinearProblem(geom)
476
- # Create a Sinkhorn solver
477
710
  solver = Sinkhorn()
478
- # Solve OT problem
479
711
  ot = solver(ot_prob, **kwargs)
480
712
  return ot.reg_ot_cost.item()
481
713
 
@@ -483,12 +715,17 @@ class WassersteinDistance(AbstractDistance):
483
715
  class EuclideanDistance(AbstractDistance):
484
716
  """Euclidean distance between pseudobulk vectors."""
485
717
 
486
- def __init__(self) -> None:
718
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
487
719
  super().__init__()
488
720
  self.accepts_precomputed = False
721
+ self.aggregation_func = aggregation_func
489
722
 
490
723
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
491
- return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs)
724
+ return np.linalg.norm(
725
+ self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
726
+ ord=2,
727
+ **kwargs,
728
+ )
492
729
 
493
730
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
494
731
  raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.")
@@ -497,12 +734,21 @@ class EuclideanDistance(AbstractDistance):
497
734
  class MeanSquaredDistance(AbstractDistance):
498
735
  """Mean squared distance between pseudobulk vectors."""
499
736
 
500
- def __init__(self) -> None:
737
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
501
738
  super().__init__()
502
739
  self.accepts_precomputed = False
740
+ self.aggregation_func = aggregation_func
503
741
 
504
742
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
505
- return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs) ** 0.5
743
+ return (
744
+ np.linalg.norm(
745
+ self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
746
+ ord=2,
747
+ **kwargs,
748
+ )
749
+ ** 2
750
+ / X.shape[1]
751
+ )
506
752
 
507
753
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
508
754
  raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.")
@@ -511,12 +757,20 @@ class MeanSquaredDistance(AbstractDistance):
511
757
  class MeanAbsoluteDistance(AbstractDistance):
512
758
  """Absolute (Norm-1) distance between pseudobulk vectors."""
513
759
 
514
- def __init__(self) -> None:
760
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
515
761
  super().__init__()
516
762
  self.accepts_precomputed = False
763
+ self.aggregation_func = aggregation_func
517
764
 
518
765
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
519
- return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=1, **kwargs)
766
+ return (
767
+ np.linalg.norm(
768
+ self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
769
+ ord=1,
770
+ **kwargs,
771
+ )
772
+ / X.shape[1]
773
+ )
520
774
 
521
775
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
522
776
  raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.")
@@ -541,12 +795,13 @@ class MeanPairwiseDistance(AbstractDistance):
541
795
  class PearsonDistance(AbstractDistance):
542
796
  """Pearson distance between pseudobulk vectors."""
543
797
 
544
- def __init__(self) -> None:
798
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
545
799
  super().__init__()
546
800
  self.accepts_precomputed = False
801
+ self.aggregation_func = aggregation_func
547
802
 
548
803
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
549
- return 1 - pearsonr(X.mean(axis=0), Y.mean(axis=0))[0]
804
+ return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
550
805
 
551
806
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
552
807
  raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.")
@@ -555,12 +810,13 @@ class PearsonDistance(AbstractDistance):
class SpearmanDistance(AbstractDistance):
    """Spearman distance between pseudobulk vectors.

    Both groups are reduced to a single profile with ``aggregation_func`` (applied along axis 0)
    and the distance is ``1 - rho``, where ``rho`` is the Spearman rank correlation of the profiles.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.accepts_precomputed = False
        self.aggregation_func = aggregation_func

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        profile_x = self.aggregation_func(X, axis=0)
        profile_y = self.aggregation_func(Y, axis=0)
        rank_correlation = spearmanr(profile_x, profile_y)[0]
        return 1 - rank_correlation

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.")
@@ -569,12 +825,13 @@ class SpearmanDistance(AbstractDistance):
569
825
  class KendallTauDistance(AbstractDistance):
570
826
  """Kendall-tau distance between pseudobulk vectors."""
571
827
 
572
- def __init__(self) -> None:
828
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
573
829
  super().__init__()
574
830
  self.accepts_precomputed = False
831
+ self.aggregation_func = aggregation_func
575
832
 
576
833
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
577
- x, y = X.mean(axis=0), Y.mean(axis=0)
834
+ x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)
578
835
  n = len(x)
579
836
  tau_corr = kendalltau(x, y).statistic
580
837
  tau_dist = (1 - tau_corr) * n * (n - 1) / 4
@@ -587,12 +844,13 @@ class KendallTauDistance(AbstractDistance):
class CosineDistance(AbstractDistance):
    """Cosine distance between pseudobulk vectors.

    Both groups are reduced to a single profile with ``aggregation_func`` (applied along axis 0)
    and the two profiles are compared with ``scipy.spatial.distance.cosine``.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.accepts_precomputed = False
        self.aggregation_func = aggregation_func

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        profile_x, profile_y = (self.aggregation_func(data, axis=0) for data in (X, Y))
        return cosine(profile_x, profile_y)

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.")
@@ -603,22 +861,24 @@ class R2ScoreDistance(AbstractDistance):
603
861
 
604
862
  # NOTE: This is not a distance metric but a similarity metric.
605
863
 
606
- def __init__(self) -> None:
864
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
607
865
  super().__init__()
608
866
  self.accepts_precomputed = False
867
+ self.aggregation_func = aggregation_func
609
868
 
610
869
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
611
- return 1 - r2_score(X.mean(axis=0), Y.mean(axis=0))
870
+ return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
612
871
 
613
872
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
614
873
  raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.")
615
874
 
616
875
 
617
- class KLDivergence(AbstractDistance):
618
- """Average of KL divergence between gene distributions of two groups
876
+ class SymmetricKLDivergence(AbstractDistance):
877
+ """Average of symmetric KL divergence between gene distributions of two groups
619
878
 
620
879
  Assuming a Gaussian distribution for each gene in each group, calculates
621
- the KL divergence between them and averages over all genes
880
+ the KL divergence between them and averages over all genes. Repeats this ABBA to get a symmetrized distance.
881
+ See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Symmetrised_divergence.
622
882
 
623
883
  """
624
884
 
@@ -632,11 +892,12 @@ class KLDivergence(AbstractDistance):
632
892
  x_mean, x_std = X[:, i].mean(), X[:, i].std() + epsilon
633
893
  y_mean, y_std = Y[:, i].mean(), Y[:, i].std() + epsilon
634
894
  kl = np.log(y_std / x_std) + (x_std**2 + (x_mean - y_mean) ** 2) / (2 * y_std**2) - 1 / 2
635
- kl_all.append(kl)
895
+ klr = np.log(x_std / y_std) + (y_std**2 + (y_mean - x_mean) ** 2) / (2 * x_std**2) - 1 / 2
896
+ kl_all.append(kl + klr)
636
897
  return sum(kl_all) / len(kl_all)
637
898
 
638
899
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
639
- raise NotImplementedError("KLDivergence cannot be called on a pairwise distance matrix.")
900
+ raise NotImplementedError("SymmetricKLDivergence cannot be called on a pairwise distance matrix.")
640
901
 
641
902
 
642
903
  class TTestDistance(AbstractDistance):
@@ -663,6 +924,23 @@ class TTestDistance(AbstractDistance):
663
924
  raise NotImplementedError("TTestDistance cannot be called on a pairwise distance matrix.")
664
925
 
665
926
 
class KSTestDistance(AbstractDistance):
    """Average of two-sided KS test statistic between two groups.

    For every column (feature), a Kolmogorov-Smirnov test compares the values in ``X`` against
    those in ``Y``; the per-feature statistics are averaged with equal weight.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        n_features = X.shape[1]
        per_feature = [abs(kstest(X[:, col], Y[:, col])[0]) for col in range(n_features)]
        return sum(per_feature) / len(per_feature)

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("KSTestDistance cannot be called on a pairwise distance matrix.")
942
+
943
+
666
944
  class NBLL(AbstractDistance):
667
945
  """
668
946
  Average of Log likelihood (scalar) of group B cells
@@ -683,16 +961,12 @@ class NBLL(AbstractDistance):
683
961
  if not _is_count_matrix(matrix=X) or not _is_count_matrix(matrix=Y):
684
962
  raise ValueError("NBLL distance only works for raw counts.")
685
963
 
686
- nlls = []
687
- for i in range(X.shape[1]):
688
- x, y = X[:, i], Y[:, i]
689
- nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
690
- mu = np.repeat(np.exp(nb_params[0]), y.shape[0])
691
- theta = np.repeat(1 / nb_params[1], y.shape[0])
692
- if mu[0] == np.nan or theta[0] == np.nan:
693
- raise ValueError("Could not fit a negative binomial distribution to the input data")
694
- # calculate the nll of y
695
- eps = np.repeat(epsilon, y.shape[0])
964
+ @numba.jit(forceobj=True)
965
+ def _compute_nll(y: np.ndarray, nb_params: tuple[float, float], epsilon: float) -> float:
966
+ mu = np.exp(nb_params[0])
967
+ theta = 1 / nb_params[1]
968
+ eps = epsilon
969
+
696
970
  log_theta_mu_eps = np.log(theta + mu + eps)
697
971
  nll = (
698
972
  theta * (np.log(theta + eps) - log_theta_mu_eps)
@@ -701,9 +975,221 @@ class NBLL(AbstractDistance):
701
975
  - gammaln(theta)
702
976
  - gammaln(y + 1)
703
977
  )
704
- nlls.append(nll.mean())
978
+ return nll.mean()
979
+
980
+ def _process_gene(x: np.ndarray, y: np.ndarray, epsilon: float) -> float:
981
+ try:
982
+ nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
983
+ return _compute_nll(y, nb_params, epsilon)
984
+ except np.linalg.linalg.LinAlgError:
985
+ if x.mean() < 10 and y.mean() < 10:
986
+ return 0.0
987
+ else:
988
+ return np.nan # Use NaN to indicate skipped genes
989
+
990
+ nlls = []
991
+ genes_skipped = 0
992
+
993
+ for i in range(X.shape[1]):
994
+ nll = _process_gene(X[:, i], Y[:, i], epsilon)
995
+ if np.isnan(nll):
996
+ genes_skipped += 1
997
+ else:
998
+ nlls.append(nll)
999
+
1000
+ if genes_skipped > X.shape[1] / 2:
1001
+ raise AttributeError(f"{genes_skipped} genes could not be fit, which is over half.")
705
1002
 
706
- return -sum(nlls) / len(nlls)
1003
+ return -np.sum(nlls) / len(nlls)
707
1004
 
708
1005
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
709
1006
  raise NotImplementedError("NBLL cannot be called on a pairwise distance matrix.")
1007
+
1008
+
1009
+ def _sample(X, frac=None, n=None):
1010
+ """Returns subsample of cells in format (train, test)."""
1011
+ if frac and n:
1012
+ raise ValueError("Cannot pass both frac and n.")
1013
+ if frac:
1014
+ n_cells = max(1, int(X.shape[0] * frac))
1015
+ elif n:
1016
+ n_cells = n
1017
+ else:
1018
+ raise ValueError("Must pass either `frac` or `n`.")
1019
+
1020
+ rng = np.random.default_rng()
1021
+ sampled_indices = rng.choice(X.shape[0], n_cells, replace=False)
1022
+ remaining_indices = np.setdiff1d(np.arange(X.shape[0]), sampled_indices)
1023
+ return X[remaining_indices, :], X[sampled_indices, :]
1024
+
1025
+
class ClassifierProbaDistance(AbstractDistance):
    """Average of classification probabilities of a binary classifier.

    Assumes the first condition is control and the second is perturbed.
    Always holds out 20% of the perturbed condition: a logistic regression is trained on all
    control cells plus the remaining 80% of perturbed cells, and the result is the mean predicted
    probability that the held-out perturbed cells belong to the perturbed class.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        Y_train, Y_test = _sample(Y, frac=0.2)
        features = np.concatenate([X, Y_train])
        targets = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]

        classifier = LogisticRegression()
        classifier.fit(features, targets)
        # sklearn orders classes_ like np.unique (sorted), so column 1 corresponds to "p".
        perturbed_proba = classifier.predict_proba(Y_test)[:, 1]
        return np.mean(perturbed_proba)

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("ClassifierProbaDistance cannot be called on a pairwise distance matrix.")
1049
+
1050
+
class ClassifierClassProjection(AbstractDistance):
    """Average of 1-(classification probability of control).

    A logistic regression is trained to discriminate between all groups except ``selected_group``.
    The cells of ``selected_group`` are then projected onto that classifier. For every other group
    the distance is ``1 - mean(predicted probability of that group)``.

    Warning: unlike all other distances, this must also take a list of categorical labels the same length as X.
    Only :meth:`onesided_distances` is supported; calling the instance directly raises.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("ClassifierClassProjection can currently only be called with onesided.")

    def onesided_distances(
        self,
        adata: AnnData,
        groupby: str,
        selected_group: str | None = None,
        groups: list[str] | None = None,
        show_progressbar: bool = True,
        n_jobs: int = -1,
        **kwargs,
    ) -> Series:
        """Unlike the parent function, all groups except the selected group are factored into the classifier.

        Similar to the parent function, the returned dataframe contains only the specified groups.

        Args:
            adata: Annotated data matrix.
            groupby: Column of ``adata.obs`` holding the group labels.
            selected_group: Group whose cells are projected onto the classifier. Required.
            groups: Groups to report; defaults to all unique values of ``adata.obs[groupby]``.
            show_progressbar: Whether to wrap the per-group loop in ``track``.
            n_jobs: Accepted for interface compatibility; not used here.

        Returns:
            Series indexed by group; ``selected_group`` itself maps to 0.

        Raises:
            ValueError: If ``selected_group`` is None. Without this check the classifier would be
                fit on every cell and prediction would fail on an empty matrix.
        """
        if selected_group is None:
            raise ValueError("ClassifierClassProjection requires `selected_group` to be specified.")

        groups = adata.obs[groupby].unique() if groups is None else groups
        fct = track if show_progressbar else lambda iterable: iterable

        group_labels = adata.obs[groupby]
        reference = adata[group_labels != selected_group]
        X = reference.X
        labels = reference.obs[groupby].values
        Y = adata[group_labels == selected_group].X

        reg = LogisticRegression()
        reg.fit(X, labels)
        test_probas = reg.predict_proba(Y)
        # Hoisted out of the loop; list.index keeps the original ValueError for unknown groups.
        classes = list(reg.classes_)

        df = pd.Series(index=groups, dtype=float)
        for group in fct(groups):
            if group == selected_group:
                df.loc[group] = 0
            else:
                df.loc[group] = 1 - np.mean(test_probas[:, classes.index(group)])
        df.index.name = groupby
        df.name = f"classifier_cp to {selected_group}"

        return df

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.")
1104
+
1105
+
class MeanVarDistributionDistance(AbstractDistance):
    """Distance between mean-var distributions of gene expression.

    For each group, the per-gene (log mean, log variance) pairs form a 2D point cloud. A kernel
    density estimate is fitted to each cloud. The distance is the mean squared difference of the
    two KDE scores evaluated on a shared grid covering both clouds.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accepts_precomputed = False

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        """Difference of mean-var distributions in 2 matrices.

        Args:
            X: Normalized and log transformed cells x genes count matrix.
            Y: Normalized and log transformed cells x genes count matrix.

        Returns:
            Mean squared difference between ``KernelDensity.score_samples`` of the two groups on
            the shared (log mean, log variance) grid.
        """

        def _mean_var(x, log: bool = False):
            # Per-gene mean and variance across cells; genes with non-positive mean are dropped.
            mean = np.mean(x, axis=0)
            var = np.var(x, axis=0)
            keep = mean > 0
            if log:
                # A zero variance would turn into -inf after the log and break the grid and KDE fit.
                keep &= var > 0
            mean = mean[keep]
            var = var[keep]
            if log:
                mean = np.log(mean)
                var = np.log(var)
            return mean, var

        def _prep_kde_data(x, y):
            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)

        def _grid_points(d, n_points=100):
            # Make grid, add 1 bin on lower/upper end to get final n_points
            d_min = d.min()
            d_max = d.max()
            # Compute bin size
            d_bin = (d_max - d_min) / (n_points - 2)
            d_min = d_min - d_bin
            d_max = d_max + d_bin
            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)

        def _parallel_score_samples(kde, samples, thread_count=None):
            if thread_count is None:
                # The factor 0.875 is recommended here:
                # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
                # Clamp to >= 1: on a single-core machine int(0.875) == 0 and Pool(0) raises ValueError.
                thread_count = max(1, int(0.875 * multiprocessing.cpu_count()))
            if thread_count == 1:
                # A single worker gives the same result without the cost of spawning a pool.
                return kde.score_samples(samples)
            with multiprocessing.Pool(thread_count) as p:
                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))

        def _kde_eval(d, grid):
            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
            # can not be compared well on regions further away from the data as they are -inf
            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
            return _parallel_score_samples(kde, grid)

        mean_x, var_x = _mean_var(X, log=True)
        mean_y, var_y = _mean_var(Y, log=True)

        x = _prep_kde_data(mean_x, var_x)
        y = _prep_kde_data(mean_y, var_y)

        # Gridpoints to eval KDE on
        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
        var_grid = _grid_points(np.concatenate([var_x, var_y]))
        grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)

        kde_x = _kde_eval(x, grid)
        kde_y = _kde_eval(y, grid)

        kde_diff = ((kde_x - kde_y) ** 2).mean()

        return kde_diff

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
1177
+
1178
+
class MahalanobisDistance(AbstractDistance):
    """Mahalanobis distance between pseudobulk vectors.

    Both groups are reduced to one profile with ``aggregation_func`` (applied along axis 0).
    The covariance that weights their difference is estimated from the cells of ``X`` only. Its
    Moore-Penrose pseudo-inverse is used, so a singular covariance does not raise. A singular
    covariance is typical when there are more features than cells.
    """

    def __init__(self, aggregation_func: Callable = np.mean) -> None:
        super().__init__()
        self.accepts_precomputed = False
        self.aggregation_func = aggregation_func

    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        # rowvar=False treats rows as observations, i.e. the same as np.cov(X.T).
        # atleast_2d: with a single feature np.cov returns a 0-d array, which inv/pinv reject.
        covariance = np.atleast_2d(np.cov(X, rowvar=False))
        # pinv equals inv for invertible matrices but stays defined for singular ones.
        inverse_covariance = np.linalg.pinv(covariance)
        return mahalanobis(
            self.aggregation_func(X, axis=0),
            self.aggregation_func(Y, axis=0),
            inverse_covariance,
        )

    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
        raise NotImplementedError("MahalanobisDistance cannot be called on a pairwise distance matrix.")