pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. pertpy/__init__.py +2 -1
  2. pertpy/data/__init__.py +61 -0
  3. pertpy/data/_dataloader.py +27 -23
  4. pertpy/data/_datasets.py +58 -0
  5. pertpy/metadata/__init__.py +2 -0
  6. pertpy/metadata/_cell_line.py +39 -70
  7. pertpy/metadata/_compound.py +3 -4
  8. pertpy/metadata/_drug.py +2 -6
  9. pertpy/metadata/_look_up.py +38 -51
  10. pertpy/metadata/_metadata.py +7 -10
  11. pertpy/metadata/_moa.py +2 -6
  12. pertpy/plot/__init__.py +0 -5
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +2 -3
  15. pertpy/tools/__init__.py +42 -4
  16. pertpy/tools/_augur.py +14 -15
  17. pertpy/tools/_cinemaot.py +2 -2
  18. pertpy/tools/_coda/_base_coda.py +118 -142
  19. pertpy/tools/_coda/_sccoda.py +16 -15
  20. pertpy/tools/_coda/_tasccoda.py +21 -22
  21. pertpy/tools/_dialogue.py +18 -23
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +21 -16
  32. pertpy/tools/_distances/_distances.py +406 -70
  33. pertpy/tools/_enrichment.py +10 -15
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +76 -53
  36. pertpy/tools/_mixscape.py +15 -11
  37. pertpy/tools/_perturbation_space/_clustering.py +5 -2
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
  41. pertpy/tools/_perturbation_space/_simple.py +3 -3
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +33 -28
  45. pertpy/tools/_scgen/_utils.py +2 -2
  46. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
  47. pertpy-0.8.0.dist-info/RECORD +57 -0
  48. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  49. pertpy/plot/_augur.py +0 -171
  50. pertpy/plot/_coda.py +0 -601
  51. pertpy/plot/_guide_rna.py +0 -64
  52. pertpy/plot/_milopy.py +0 -209
  53. pertpy/plot/_mixscape.py +0 -355
  54. pertpy/tools/_differential_gene_expression.py +0 -325
  55. pertpy-0.7.0.dist-info/RECORD +0 -53
  56. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import multiprocessing
3
4
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING
5
+ from typing import TYPE_CHECKING, Literal, NamedTuple
5
6
 
6
7
  import numba
7
8
  import numpy as np
@@ -13,18 +14,26 @@ from ott.solvers.linear.sinkhorn import Sinkhorn
13
14
  from pandas import Series
14
15
  from rich.progress import track
15
16
  from scipy.sparse import issparse
16
- from scipy.spatial.distance import cosine
17
+ from scipy.spatial.distance import cosine, mahalanobis
17
18
  from scipy.special import gammaln
18
19
  from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
19
20
  from sklearn.linear_model import LogisticRegression
20
21
  from sklearn.metrics import pairwise_distances, r2_score
21
22
  from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
23
+ from sklearn.neighbors import KernelDensity
22
24
  from statsmodels.discrete.discrete_model import NegativeBinomialP
23
25
 
24
26
  if TYPE_CHECKING:
27
+ from collections.abc import Callable
28
+
25
29
  from anndata import AnnData
26
30
 
27
31
 
32
+ class MeanVar(NamedTuple):
33
+ mean: float
34
+ variance: float
35
+
36
+
28
37
  class Distance:
29
38
  """Distance class, used to compute distances between groups of cells.
30
39
 
@@ -80,6 +89,11 @@ class Distance:
80
89
  Average of the classification probability of the perturbation for a binary classifier.
81
90
  - "classifier_cp": classifier class projection
82
91
  Average of the class
92
+ - "mean_var_distribution": Distance between the mean-variance distributions of cells from 2 groups.
93
+ Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
94
+ - "mahalanobis": Mahalanobis distance between the means of cells from two groups.
95
+ It was originally defined to measure the distance between a point and a distribution.
96
+ In this context, it quantifies the difference between the mean profiles of a target group and a reference group.
83
97
 
84
98
  Attributes:
85
99
  metric: Name of distance metric.
@@ -99,6 +113,7 @@ class Distance:
99
113
  def __init__(
100
114
  self,
101
115
  metric: str = "edistance",
116
+ agg_fct: Callable = np.mean,
102
117
  layer_key: str = None,
103
118
  obsm_key: str = None,
104
119
  cell_wise_metric: str = "euclidean",
@@ -106,37 +121,38 @@ class Distance:
106
121
  """Initialize Distance class.
107
122
 
108
123
  Args:
109
- metric: Distance metric to use. Defaults to "edistance".
124
+ metric: Distance metric to use.
125
+ agg_fct: Aggregation function to generate pseudobulk vectors.
110
126
  layer_key: Name of the counts layer containing raw counts to calculate distances for.
111
127
  Mutually exclusive with 'obsm_key'.
112
- Defaults to None and is then not used.
128
+ Is not used if `None`.
113
129
  obsm_key: Name of embedding in adata.obsm to use.
114
- Mutually exclusive with 'counts_layer_key'.
115
- Defaults to None, but is set to "X_pca" if not set explicitly internally.
130
+ Mutually exclusive with 'layer_key'.
131
+ Defaults to None; if neither `layer_key` nor `obsm_key` is given, "X_pca" is used internally.
116
132
  cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
117
- Defaults to "euclidean".
118
133
  """
119
134
  metric_fct: AbstractDistance = None
135
+ self.aggregation_func = agg_fct
120
136
  if metric == "edistance":
121
137
  metric_fct = Edistance()
122
138
  elif metric == "euclidean":
123
- metric_fct = EuclideanDistance()
139
+ metric_fct = EuclideanDistance(self.aggregation_func)
124
140
  elif metric == "root_mean_squared_error":
125
- metric_fct = EuclideanDistance()
141
+ metric_fct = EuclideanDistance(self.aggregation_func)
126
142
  elif metric == "mse":
127
- metric_fct = MeanSquaredDistance()
143
+ metric_fct = MeanSquaredDistance(self.aggregation_func)
128
144
  elif metric == "mean_absolute_error":
129
- metric_fct = MeanAbsoluteDistance()
145
+ metric_fct = MeanAbsoluteDistance(self.aggregation_func)
130
146
  elif metric == "pearson_distance":
131
- metric_fct = PearsonDistance()
147
+ metric_fct = PearsonDistance(self.aggregation_func)
132
148
  elif metric == "spearman_distance":
133
- metric_fct = SpearmanDistance()
149
+ metric_fct = SpearmanDistance(self.aggregation_func)
134
150
  elif metric == "kendalltau_distance":
135
- metric_fct = KendallTauDistance()
151
+ metric_fct = KendallTauDistance(self.aggregation_func)
136
152
  elif metric == "cosine_distance":
137
- metric_fct = CosineDistance()
153
+ metric_fct = CosineDistance(self.aggregation_func)
138
154
  elif metric == "r2_distance":
139
- metric_fct = R2ScoreDistance()
155
+ metric_fct = R2ScoreDistance(self.aggregation_func)
140
156
  elif metric == "mean_pairwise":
141
157
  metric_fct = MeanPairwiseDistance()
142
158
  elif metric == "mmd":
@@ -155,14 +171,17 @@ class Distance:
155
171
  metric_fct = ClassifierProbaDistance()
156
172
  elif metric == "classifier_cp":
157
173
  metric_fct = ClassifierClassProjection()
174
+ elif metric == "mean_var_distribution":
175
+ metric_fct = MeanVarDistributionDistance()
176
+ elif metric == "mahalanobis":
177
+ metric_fct = MahalanobisDistance(self.aggregation_func)
158
178
  else:
159
179
  raise ValueError(f"Metric {metric} not recognized.")
160
180
  self.metric_fct = metric_fct
161
181
 
162
182
  if layer_key and obsm_key:
163
183
  raise ValueError(
164
- "Cannot use 'counts_layer_key' and 'obsm_key' at the same time.\n"
165
- "Please provide only one of the two keys."
184
+ "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
166
185
  )
167
186
  if not layer_key and not obsm_key:
168
187
  obsm_key = "X_pca"
@@ -195,37 +214,80 @@ class Distance:
195
214
  >>> D = Distance(X, Y)
196
215
  """
197
216
  if issparse(X):
198
- X = X.A
217
+ X = X.toarray()
199
218
  if issparse(Y):
200
- Y = Y.A
219
+ Y = Y.toarray()
201
220
 
202
221
  if len(X) == 0 or len(Y) == 0:
203
222
  raise ValueError("Neither X nor Y can be empty.")
204
223
 
205
224
  return self.metric_fct(X, Y, **kwargs)
206
225
 
226
+ def bootstrap(
227
+ self,
228
+ X: np.ndarray,
229
+ Y: np.ndarray,
230
+ *,
231
+ n_bootstrap: int = 100,
232
+ random_state: int = 0,
233
+ **kwargs,
234
+ ) -> MeanVar:
235
+ """Bootstrap computation of mean and variance of the distance between vectors X and Y.
236
+
237
+ Args:
238
+ X: First vector of shape (n_samples, n_features).
239
+ Y: Second vector of shape (n_samples, n_features).
240
+ n_bootstrap: Number of bootstrap samples.
241
+ random_state: Random state for bootstrapping.
242
+
243
+ Returns:
244
+ MeanVar: Mean and variance of distance between X and Y.
245
+
246
+ Examples:
247
+ >>> import pertpy as pt
248
+ >>> adata = pt.dt.distance_example()
249
+ >>> Distance = pt.tools.Distance(metric="edistance")
250
+ >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
251
+ >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
252
+ >>> D = Distance.bootstrap(X, Y)
253
+ """
254
+ return self._bootstrap_mode(
255
+ X,
256
+ Y,
257
+ n_bootstraps=n_bootstrap,
258
+ random_state=random_state,
259
+ **kwargs,
260
+ )
261
+
207
262
  def pairwise(
208
263
  self,
209
264
  adata: AnnData,
210
265
  groupby: str,
211
266
  groups: list[str] | None = None,
267
+ bootstrap: bool = False,
268
+ n_bootstrap: int = 100,
269
+ random_state: int = 0,
212
270
  show_progressbar: bool = True,
213
271
  n_jobs: int = -1,
214
272
  **kwargs,
215
- ) -> pd.DataFrame:
273
+ ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
216
274
  """Get pairwise distances between groups of cells.
217
275
 
218
276
  Args:
219
277
  adata: Annotated data matrix.
220
278
  groupby: Column name in adata.obs.
221
279
  groups: List of groups to compute pairwise distances for.
222
- If None, uses all groups. Defaults to None.
223
- show_progressbar: Whether to show progress bar. Defaults to True.
280
+ If None, uses all groups.
281
+ bootstrap: Whether to bootstrap the distance.
282
+ n_bootstrap: Number of bootstrap samples.
283
+ random_state: Random state for bootstrapping.
284
+ show_progressbar: Whether to show progress bar.
224
285
  n_jobs: Number of cores to use. Defaults to -1 (all).
225
286
  kwargs: Additional keyword arguments passed to the metric function.
226
287
 
227
288
  Returns:
228
289
  pd.DataFrame: Dataframe with pairwise distances.
290
+ tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
229
291
 
230
292
  Examples:
231
293
  >>> import pertpy as pt
@@ -236,6 +298,8 @@ class Distance:
236
298
  groups = adata.obs[groupby].unique() if groups is None else groups
237
299
  grouping = adata.obs[groupby].copy()
238
300
  df = pd.DataFrame(index=groups, columns=groups, dtype=float)
301
+ if bootstrap:
302
+ df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
239
303
  fct = track if show_progressbar else lambda iterable: iterable
240
304
 
241
305
  # Some metrics are able to handle precomputed distances. This means that
@@ -251,16 +315,29 @@ class Distance:
251
315
  for index_x, group_x in enumerate(fct(groups)):
252
316
  idx_x = grouping == group_x
253
317
  for group_y in groups[index_x:]: # type: ignore
254
- if group_x == group_y:
255
- dist = 0.0 # by distance axiom
318
+ # subset the pairwise distance matrix to the two groups
319
+ idx_y = grouping == group_y
320
+ sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
321
+ sub_idx = grouping[idx_x | idx_y] == group_x
322
+ if not bootstrap:
323
+ if group_x == group_y:
324
+ dist = 0.0
325
+ else:
326
+ dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
327
+ df.loc[group_x, group_y] = dist
328
+ df.loc[group_y, group_x] = dist
329
+
256
330
  else:
257
- idx_y = grouping == group_y
258
- # subset the pairwise distance matrix to the two groups
259
- sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
260
- sub_idx = grouping[idx_x | idx_y] == group_x
261
- dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
262
- df.loc[group_x, group_y] = dist
263
- df.loc[group_y, group_x] = dist
331
+ bootstrap_output = self._bootstrap_mode_precomputed(
332
+ sub_pwd,
333
+ sub_idx,
334
+ n_bootstraps=n_bootstrap,
335
+ random_state=random_state,
336
+ **kwargs,
337
+ )
338
+ # In the bootstrap case, distance of group to itself is a mean and can be non-zero
339
+ df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
340
+ df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
264
341
  else:
265
342
  if self.layer_key:
266
343
  embedding = adata.layers[self.layer_key]
@@ -269,18 +346,39 @@ class Distance:
269
346
  for index_x, group_x in enumerate(fct(groups)):
270
347
  cells_x = embedding[grouping == group_x].copy()
271
348
  for group_y in groups[index_x:]: # type: ignore
272
- if group_x == group_y:
273
- dist = 0.0
349
+ cells_y = embedding[grouping == group_y].copy()
350
+ if not bootstrap:
351
+ # By distance axiom, the distance between a group and itself is 0
352
+ dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
353
+
354
+ df.loc[group_x, group_y] = dist
355
+ df.loc[group_y, group_x] = dist
274
356
  else:
275
- cells_y = embedding[grouping == group_y].copy()
276
- dist = self(cells_x, cells_y, **kwargs)
277
- df.loc[group_x, group_y] = dist
278
- df.loc[group_y, group_x] = dist
357
+ bootstrap_output = self.bootstrap(
358
+ cells_x,
359
+ cells_y,
360
+ n_bootstrap=n_bootstrap,
361
+ random_state=random_state,
362
+ **kwargs,
363
+ )
364
+ # In the bootstrap case, distance of group to itself is a mean and can be non-zero
365
+ df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
366
+ df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
367
+
279
368
  df.index.name = groupby
280
369
  df.columns.name = groupby
281
370
  df.name = f"pairwise {self.metric}"
282
371
 
283
- return df
372
+ if not bootstrap:
373
+ return df
374
+ else:
375
+ df = df.fillna(0)
376
+ df_var.index.name = groupby
377
+ df_var.columns.name = groupby
378
+ df_var = df_var.fillna(0)
379
+ df_var.name = f"pairwise {self.metric} variance"
380
+
381
+ return df, df_var
284
382
 
285
383
  def onesided_distances(
286
384
  self,
@@ -288,10 +386,13 @@ class Distance:
288
386
  groupby: str,
289
387
  selected_group: str | None = None,
290
388
  groups: list[str] | None = None,
389
+ bootstrap: bool = False,
390
+ n_bootstrap: int = 100,
391
+ random_state: int = 0,
291
392
  show_progressbar: bool = True,
292
393
  n_jobs: int = -1,
293
394
  **kwargs,
294
- ) -> pd.DataFrame:
395
+ ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
295
396
  """Get distances between one selected cell group and the remaining other cell groups.
296
397
 
297
398
  Args:
@@ -299,13 +400,18 @@ class Distance:
299
400
  groupby: Column name in adata.obs.
300
401
  selected_group: Group to compute pairwise distances to all other.
301
402
  groups: List of groups to compute distances to selected_group for.
302
- If None, uses all groups. Defaults to None.
303
- show_progressbar: Whether to show progress bar. Defaults to True.
403
+ If None, uses all groups.
404
+ bootstrap: Whether to bootstrap the distance.
405
+ n_bootstrap: Number of bootstrap samples.
406
+ random_state: Random state for bootstrapping.
407
+ show_progressbar: Whether to show progress bar.
304
408
  n_jobs: Number of cores to use. Defaults to -1 (all).
305
409
  kwargs: Additional keyword arguments passed to the metric function.
306
410
 
307
411
  Returns:
308
412
  pd.DataFrame: Dataframe with distances of groups to selected_group.
413
+ tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
414
+
309
415
 
310
416
  Examples:
311
417
  >>> import pertpy as pt
@@ -314,20 +420,30 @@ class Distance:
314
420
  >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
315
421
  """
316
422
  if self.metric == "classifier_cp":
423
+ if bootstrap:
424
+ raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.")
317
425
  return self.metric_fct.onesided_distances( # type: ignore
318
- adata, groupby, selected_group, groups, show_progressbar, n_jobs, **kwargs
426
+ adata,
427
+ groupby,
428
+ selected_group,
429
+ groups,
430
+ show_progressbar,
431
+ n_jobs,
432
+ **kwargs,
319
433
  )
320
434
 
321
435
  groups = adata.obs[groupby].unique() if groups is None else groups
322
436
  grouping = adata.obs[groupby].copy()
323
437
  df = pd.Series(index=groups, dtype=float)
438
+ if bootstrap:
439
+ df_var = pd.Series(index=groups, dtype=float)
324
440
  fct = track if show_progressbar else lambda iterable: iterable
325
441
 
326
442
  # Some metrics are able to handle precomputed distances. This means that
327
443
  # the pairwise distances between all cells are computed once and then
328
444
  # passed to the metric function. This is much faster than computing the
329
445
  # pairwise distances for each group separately. Other metrics are not
330
- # able to handle precomputed distances such as the PsuedobulkDistance.
446
+ # able to handle precomputed distances such as the PseudobulkDistance.
331
447
  if self.metric_fct.accepts_precomputed:
332
448
  # Precompute the pairwise distances if needed
333
449
  if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
@@ -337,14 +453,25 @@ class Distance:
337
453
  idx_x = grouping == group_x
338
454
  group_y = selected_group
339
455
  if group_x == group_y:
340
- dist = 0.0 # by distance axiom
456
+ df.loc[group_x] = 0.0 # by distance axiom
341
457
  else:
342
458
  idx_y = grouping == group_y
343
459
  # subset the pairwise distance matrix to the two groups
344
460
  sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
345
461
  sub_idx = grouping[idx_x | idx_y] == group_x
346
- dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
347
- df.loc[group_x] = dist
462
+ if not bootstrap:
463
+ dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
464
+ df.loc[group_x] = dist
465
+ else:
466
+ bootstrap_output = self._bootstrap_mode_precomputed(
467
+ sub_pwd,
468
+ sub_idx,
469
+ n_bootstraps=n_bootstrap,
470
+ random_state=random_state,
471
+ **kwargs,
472
+ )
473
+ df.loc[group_x] = bootstrap_output.mean
474
+ df_var.loc[group_x] = bootstrap_output.variance
348
475
  else:
349
476
  if self.layer_key:
350
477
  embedding = adata.layers[self.layer_key]
@@ -353,15 +480,32 @@ class Distance:
353
480
  for group_x in fct(groups):
354
481
  cells_x = embedding[grouping == group_x].copy()
355
482
  group_y = selected_group
356
- if group_x == group_y:
357
- dist = 0.0
483
+ cells_y = embedding[grouping == group_y].copy()
484
+ if not bootstrap:
485
+ # By distance axiom, the distance between a group and itself is 0
486
+ dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
487
+ df.loc[group_x] = dist
358
488
  else:
359
- cells_y = embedding[grouping == group_y].copy()
360
- dist = self(cells_x, cells_y, **kwargs)
361
- df.loc[group_x] = dist
489
+ bootstrap_output = self.bootstrap(
490
+ cells_x,
491
+ cells_y,
492
+ n_bootstrap=n_bootstrap,
493
+ random_state=random_state,
494
+ **kwargs,
495
+ )
496
+ # In the bootstrap case, distance of group to itself is a mean and can be non-zero
497
+ df.loc[group_x] = bootstrap_output.mean
498
+ df_var.loc[group_x] = bootstrap_output.variance
362
499
  df.index.name = groupby
363
500
  df.name = f"{self.metric} to {selected_group}"
364
- return df
501
+ if not bootstrap:
502
+ return df
503
+ else:
504
+ df_var.index.name = groupby
505
+ df_var = df_var.fillna(0)
506
+ df_var.name = f"pairwise {self.metric} variance to {selected_group}"
507
+
508
+ return df, df_var
365
509
 
366
510
  def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
367
511
  """Precompute pairwise distances between all cells, writes to adata.obsp.
@@ -387,6 +531,77 @@ class Distance:
387
531
  pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
388
532
  adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
389
533
 
534
+ def compare_distance(
535
+ self,
536
+ pert: np.ndarray,
537
+ pred: np.ndarray,
538
+ ctrl: np.ndarray,
539
+ mode: Literal["simple", "scaled"] = "simple",
540
+ fit_to_pert_and_ctrl: bool = False,
541
+ **kwargs,
542
+ ) -> float:
543
+ """Compute the score of simulating a perturbation.
544
+
545
+ Args:
546
+ pert: Real perturbed data.
547
+ pred: Simulated perturbed data.
548
+ ctrl: Control data
549
+ mode: Mode to use.
550
+ fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
551
+ kwargs: Additional keyword arguments passed to the metric function.
552
+ """
553
+ if mode == "simple":
554
+ pass # nothing to be done
555
+ elif mode == "scaled":
556
+ from sklearn.preprocessing import MinMaxScaler
557
+
558
+ scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl)
559
+ pred = scaler.transform(pred)
560
+ pert = scaler.transform(pert)
561
+ else:
562
+ raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")
563
+
564
+ d1 = self.metric_fct(pert, pred, **kwargs)
565
+ d2 = self.metric_fct(ctrl, pred, **kwargs)
566
+ return d1 / d2
567
+
568
+ def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
569
+ rng = np.random.default_rng(random_state)
570
+
571
+ distances = []
572
+ for _ in range(n_bootstraps):
573
+ X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)]
574
+ Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)]
575
+
576
+ distance = self(X_bootstrapped, Y_bootstrapped, **kwargs)
577
+ distances.append(distance)
578
+
579
+ mean = np.mean(distances)
580
+ variance = np.var(distances)
581
+ return MeanVar(mean=mean, variance=variance)
582
+
583
+ def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
584
+ rng = np.random.default_rng(random_state)
585
+
586
+ distances = []
587
+ for _ in range(n_bootstraps):
588
+ # To maintain the number of cells for both groups (whatever balancing they may have),
589
+ # we sample the positive and negative indices separately
590
+ bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True)
591
+ bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True)
592
+ bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx])
593
+ bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx)
594
+
595
+ bootstrap_sub_idx = sub_idx[bootstrap_idx]
596
+ bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs]
597
+
598
+ distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs)
599
+ distances.append(distance)
600
+
601
+ mean = np.mean(distances)
602
+ variance = np.var(distances)
603
+ return MeanVar(mean=mean, variance=variance)
604
+
390
605
 
391
606
  class AbstractDistance(ABC):
392
607
  """Abstract class of distance metrics between two sets of vectors."""
@@ -500,12 +715,17 @@ class WassersteinDistance(AbstractDistance):
500
715
  class EuclideanDistance(AbstractDistance):
501
716
  """Euclidean distance between pseudobulk vectors."""
502
717
 
503
- def __init__(self) -> None:
718
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
504
719
  super().__init__()
505
720
  self.accepts_precomputed = False
721
+ self.aggregation_func = aggregation_func
506
722
 
507
723
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
508
- return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs)
724
+ return np.linalg.norm(
725
+ self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
726
+ ord=2,
727
+ **kwargs,
728
+ )
509
729
 
510
730
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
511
731
  raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.")
@@ -514,12 +734,21 @@ class EuclideanDistance(AbstractDistance):
514
734
  class MeanSquaredDistance(AbstractDistance):
515
735
  """Mean squared distance between pseudobulk vectors."""
516
736
 
517
- def __init__(self) -> None:
737
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
518
738
  super().__init__()
519
739
  self.accepts_precomputed = False
740
+ self.aggregation_func = aggregation_func
520
741
 
521
742
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
522
- return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs) ** 2 / X.shape[1]
743
+ return (
744
+ np.linalg.norm(
745
+ self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
746
+ ord=2,
747
+ **kwargs,
748
+ )
749
+ ** 2
750
+ / X.shape[1]
751
+ )
523
752
 
524
753
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
525
754
  raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.")
@@ -528,12 +757,20 @@ class MeanSquaredDistance(AbstractDistance):
528
757
  class MeanAbsoluteDistance(AbstractDistance):
529
758
  """Absolute (Norm-1) distance between pseudobulk vectors."""
530
759
 
531
- def __init__(self) -> None:
760
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
532
761
  super().__init__()
533
762
  self.accepts_precomputed = False
763
+ self.aggregation_func = aggregation_func
534
764
 
535
765
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
536
- return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=1, **kwargs) / X.shape[1]
766
+ return (
767
+ np.linalg.norm(
768
+ self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
769
+ ord=1,
770
+ **kwargs,
771
+ )
772
+ / X.shape[1]
773
+ )
537
774
 
538
775
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
539
776
  raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.")
@@ -558,12 +795,13 @@ class MeanPairwiseDistance(AbstractDistance):
558
795
  class PearsonDistance(AbstractDistance):
559
796
  """Pearson distance between pseudobulk vectors."""
560
797
 
561
- def __init__(self) -> None:
798
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
562
799
  super().__init__()
563
800
  self.accepts_precomputed = False
801
+ self.aggregation_func = aggregation_func
564
802
 
565
803
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
566
- return 1 - pearsonr(X.mean(axis=0), Y.mean(axis=0))[0]
804
+ return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
567
805
 
568
806
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
569
807
  raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.")
@@ -572,12 +810,13 @@ class PearsonDistance(AbstractDistance):
572
810
  class SpearmanDistance(AbstractDistance):
573
811
  """Spearman distance between pseudobulk vectors."""
574
812
 
575
- def __init__(self) -> None:
813
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
576
814
  super().__init__()
577
815
  self.accepts_precomputed = False
816
+ self.aggregation_func = aggregation_func
578
817
 
579
818
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
580
- return 1 - spearmanr(X.mean(axis=0), Y.mean(axis=0))[0]
819
+ return 1 - spearmanr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
581
820
 
582
821
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
583
822
  raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.")
@@ -586,12 +825,13 @@ class SpearmanDistance(AbstractDistance):
586
825
  class KendallTauDistance(AbstractDistance):
587
826
  """Kendall-tau distance between pseudobulk vectors."""
588
827
 
589
- def __init__(self) -> None:
828
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
590
829
  super().__init__()
591
830
  self.accepts_precomputed = False
831
+ self.aggregation_func = aggregation_func
592
832
 
593
833
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
594
- x, y = X.mean(axis=0), Y.mean(axis=0)
834
+ x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)
595
835
  n = len(x)
596
836
  tau_corr = kendalltau(x, y).statistic
597
837
  tau_dist = (1 - tau_corr) * n * (n - 1) / 4
@@ -604,12 +844,13 @@ class KendallTauDistance(AbstractDistance):
604
844
  class CosineDistance(AbstractDistance):
605
845
  """Cosine distance between pseudobulk vectors."""
606
846
 
607
- def __init__(self) -> None:
847
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
608
848
  super().__init__()
609
849
  self.accepts_precomputed = False
850
+ self.aggregation_func = aggregation_func
610
851
 
611
852
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
612
- return cosine(X.mean(axis=0), Y.mean(axis=0))
853
+ return cosine(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
613
854
 
614
855
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
615
856
  raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.")
@@ -620,12 +861,13 @@ class R2ScoreDistance(AbstractDistance):
620
861
 
621
862
  # NOTE: This is not a distance metric but a similarity metric.
622
863
 
623
- def __init__(self) -> None:
864
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
624
865
  super().__init__()
625
866
  self.accepts_precomputed = False
867
+ self.aggregation_func = aggregation_func
626
868
 
627
869
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
628
- return 1 - r2_score(X.mean(axis=0), Y.mean(axis=0))
870
+ return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
629
871
 
630
872
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
631
873
  raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.")
@@ -834,6 +1076,7 @@ class ClassifierClassProjection(AbstractDistance):
834
1076
  Similar to the parent function, the returned dataframe contains only the specified groups.
835
1077
  """
836
1078
  groups = adata.obs[groupby].unique() if groups is None else groups
1079
+ fct = track if show_progressbar else lambda iterable: iterable
837
1080
 
838
1081
  X = adata[adata.obs[groupby] != selected_group].X
839
1082
  labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values
@@ -844,7 +1087,8 @@ class ClassifierClassProjection(AbstractDistance):
844
1087
  test_probas = reg.predict_proba(Y)
845
1088
 
846
1089
  df = pd.Series(index=groups, dtype=float)
847
- for group in groups:
1090
+
1091
+ for group in fct(groups):
848
1092
  if group == selected_group:
849
1093
  df.loc[group] = 0
850
1094
  else:
@@ -857,3 +1101,95 @@ class ClassifierClassProjection(AbstractDistance):
857
1101
 
858
1102
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
859
1103
  raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.")
1104
+
1105
+
1106
+ class MeanVarDistributionDistance(AbstractDistance):
1107
+ """Distance between mean-var distributions of gene expression."""
1108
+
1109
+ def __init__(self) -> None:
1110
+ super().__init__()
1111
+ self.accepts_precomputed = False
1112
+
1113
+ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
1114
+ """Difference of mean-var distributions in 2 matrices.
1115
+
1116
+ Args:
1117
+ X: Normalized and log transformed cells x genes count matrix.
1118
+ Y: Normalized and log transformed cells x genes count matrix.
1119
+ """
1120
+
1121
+ def _mean_var(x, log: bool = False):
1122
+ mean = np.mean(x, axis=0)
1123
+ var = np.var(x, axis=0)
1124
+ positive = mean > 0
1125
+ mean = mean[positive]
1126
+ var = var[positive]
1127
+ if log:
1128
+ mean = np.log(mean)
1129
+ var = np.log(var)
1130
+ return mean, var
1131
+
1132
+ def _prep_kde_data(x, y):
1133
+ return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1134
+
1135
+ def _grid_points(d, n_points=100):
1136
+ # Make grid, add 1 bin on lower/upper end to get final n_points
1137
+ d_min = d.min()
1138
+ d_max = d.max()
1139
+ # Compute bin size
1140
+ d_bin = (d_max - d_min) / (n_points - 2)
1141
+ d_min = d_min - d_bin
1142
+ d_max = d_max + d_bin
1143
+ return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1144
+
1145
+ def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
1146
+ # the thread_count is determined using the factor 0.875 as recommended here:
1147
+ # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
1148
+ with multiprocessing.Pool(thread_count) as p:
1149
+ return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
1150
+
1151
+ def _kde_eval(d, grid):
1152
+ # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
1153
+ # can not be compared well on regions further away from the data as they are -inf
1154
+ kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
1155
+ return _parallel_score_samples(kde, grid)
1156
+
1157
+ mean_x, var_x = _mean_var(X, log=True)
1158
+ mean_y, var_y = _mean_var(Y, log=True)
1159
+
1160
+ x = _prep_kde_data(mean_x, var_x)
1161
+ y = _prep_kde_data(mean_y, var_y)
1162
+
1163
+ # Gridpoints to eval KDE on
1164
+ mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
1165
+ var_grid = _grid_points(np.concatenate([var_x, var_y]))
1166
+ grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
1167
+
1168
+ kde_x = _kde_eval(x, grid)
1169
+ kde_y = _kde_eval(y, grid)
1170
+
1171
+ kde_diff = ((kde_x - kde_y) ** 2).mean()
1172
+
1173
+ return kde_diff
1174
+
1175
+ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
1176
+ raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
1177
+
1178
+
1179
+ class MahalanobisDistance(AbstractDistance):
1180
+ """Mahalanobis distance between pseudobulk vectors."""
1181
+
1182
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
1183
+ super().__init__()
1184
+ self.accepts_precomputed = False
1185
+ self.aggregation_func = aggregation_func
1186
+
1187
+ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
1188
+ return mahalanobis(
1189
+ self.aggregation_func(X, axis=0),
1190
+ self.aggregation_func(Y, axis=0),
1191
+ np.linalg.inv(np.cov(X.T)),
1192
+ )
1193
+
1194
+ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
1195
+ raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.")