pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. pertpy/__init__.py +2 -1
  2. pertpy/data/__init__.py +61 -0
  3. pertpy/data/_dataloader.py +27 -23
  4. pertpy/data/_datasets.py +58 -0
  5. pertpy/metadata/__init__.py +2 -0
  6. pertpy/metadata/_cell_line.py +39 -70
  7. pertpy/metadata/_compound.py +3 -4
  8. pertpy/metadata/_drug.py +2 -6
  9. pertpy/metadata/_look_up.py +38 -51
  10. pertpy/metadata/_metadata.py +7 -10
  11. pertpy/metadata/_moa.py +2 -6
  12. pertpy/plot/__init__.py +0 -5
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +2 -3
  15. pertpy/tools/__init__.py +42 -4
  16. pertpy/tools/_augur.py +14 -15
  17. pertpy/tools/_cinemaot.py +2 -2
  18. pertpy/tools/_coda/_base_coda.py +118 -142
  19. pertpy/tools/_coda/_sccoda.py +16 -15
  20. pertpy/tools/_coda/_tasccoda.py +21 -22
  21. pertpy/tools/_dialogue.py +18 -23
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +21 -16
  32. pertpy/tools/_distances/_distances.py +406 -70
  33. pertpy/tools/_enrichment.py +10 -15
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +76 -53
  36. pertpy/tools/_mixscape.py +15 -11
  37. pertpy/tools/_perturbation_space/_clustering.py +5 -2
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
  41. pertpy/tools/_perturbation_space/_simple.py +3 -3
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +33 -28
  45. pertpy/tools/_scgen/_utils.py +2 -2
  46. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
  47. pertpy-0.8.0.dist-info/RECORD +57 -0
  48. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  49. pertpy/plot/_augur.py +0 -171
  50. pertpy/plot/_coda.py +0 -601
  51. pertpy/plot/_guide_rna.py +0 -64
  52. pertpy/plot/_milopy.py +0 -209
  53. pertpy/plot/_mixscape.py +0 -355
  54. pertpy/tools/_differential_gene_expression.py +0 -325
  55. pertpy-0.7.0.dist-info/RECORD +0 -53
  56. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import multiprocessing
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal, NamedTuple
 
 import numba
 import numpy as np
@@ -13,18 +14,26 @@ from ott.solvers.linear.sinkhorn import Sinkhorn
 from pandas import Series
 from rich.progress import track
 from scipy.sparse import issparse
-from scipy.spatial.distance import cosine
+from scipy.spatial.distance import cosine, mahalanobis
 from scipy.special import gammaln
 from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import pairwise_distances, r2_score
 from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
+from sklearn.neighbors import KernelDensity
 from statsmodels.discrete.discrete_model import NegativeBinomialP
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from anndata import AnnData
 
 
+class MeanVar(NamedTuple):
+    mean: float
+    variance: float
+
+
 class Distance:
     """Distance class, used to compute distances between groups of cells.
 
@@ -80,6 +89,11 @@ class Distance:
         Average of the classification probability of the perturbation for a binary classifier.
     - "classifier_cp": classifier class projection
         Average of the class
+    - "mean_var_distribution": Distance between mean-variance distributions between cells of 2 groups.
+        Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
+    - "mahalanobis": Mahalanobis distance between the means of cells from two groups.
+        It is originally used to measure distance between a point and a distribution.
+        in this context, it quantifies the difference between the mean profiles of a target group and a reference group.
 
     Attributes:
         metric: Name of distance metric.
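The two new metrics described in this hunk can be selected like any existing metric. A minimal sketch in the doctest style used throughout this file; note that `"mean_var_distribution"` expects normalized, log-transformed counts, so the `"lognorm"` layer name passed to `layer_key` below is an illustrative assumption, not a layer guaranteed to exist in the example dataset:

>>> import pertpy as pt
>>> adata = pt.dt.distance_example()
>>> mahalanobis_distance = pt.tools.Distance(metric="mahalanobis")
>>> mean_var_distance = pt.tools.Distance(metric="mean_var_distribution", layer_key="lognorm")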
@@ -99,6 +113,7 @@ class Distance:
     def __init__(
         self,
         metric: str = "edistance",
+        agg_fct: Callable = np.mean,
         layer_key: str = None,
         obsm_key: str = None,
         cell_wise_metric: str = "euclidean",
@@ -106,37 +121,38 @@ class Distance:
         """Initialize Distance class.
 
         Args:
-            metric: Distance metric to use. Defaults to "edistance".
+            metric: Distance metric to use.
+            agg_fct: Aggregation function to generate pseudobulk vectors.
             layer_key: Name of the counts layer containing raw counts to calculate distances for.
                 Mutually exclusive with 'obsm_key'.
-                Defaults to None and is then not used.
+                Is not used if `None`.
             obsm_key: Name of embedding in adata.obsm to use.
-                Mutually exclusive with 'counts_layer_key'.
-                Defaults to None, but is set to "X_pca" if not set explicitly internally.
+                Mutually exclusive with 'layer_key'.
+                Defaults to None, but is set to "X_pca" if not explicitly set internally.
             cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
-                Defaults to "euclidean".
         """
         metric_fct: AbstractDistance = None
+        self.aggregation_func = agg_fct
         if metric == "edistance":
             metric_fct = Edistance()
         elif metric == "euclidean":
-            metric_fct = EuclideanDistance()
+            metric_fct = EuclideanDistance(self.aggregation_func)
         elif metric == "root_mean_squared_error":
-            metric_fct = EuclideanDistance()
+            metric_fct = EuclideanDistance(self.aggregation_func)
         elif metric == "mse":
-            metric_fct = MeanSquaredDistance()
+            metric_fct = MeanSquaredDistance(self.aggregation_func)
         elif metric == "mean_absolute_error":
-            metric_fct = MeanAbsoluteDistance()
+            metric_fct = MeanAbsoluteDistance(self.aggregation_func)
         elif metric == "pearson_distance":
-            metric_fct = PearsonDistance()
+            metric_fct = PearsonDistance(self.aggregation_func)
         elif metric == "spearman_distance":
-            metric_fct = SpearmanDistance()
+            metric_fct = SpearmanDistance(self.aggregation_func)
         elif metric == "kendalltau_distance":
-            metric_fct = KendallTauDistance()
+            metric_fct = KendallTauDistance(self.aggregation_func)
         elif metric == "cosine_distance":
-            metric_fct = CosineDistance()
+            metric_fct = CosineDistance(self.aggregation_func)
         elif metric == "r2_distance":
-            metric_fct = R2ScoreDistance()
+            metric_fct = R2ScoreDistance(self.aggregation_func)
         elif metric == "mean_pairwise":
             metric_fct = MeanPairwiseDistance()
         elif metric == "mmd":
@@ -155,14 +171,17 @@ class Distance:
             metric_fct = ClassifierProbaDistance()
         elif metric == "classifier_cp":
             metric_fct = ClassifierClassProjection()
+        elif metric == "mean_var_distribution":
+            metric_fct = MeanVarDistributionDistance()
+        elif metric == "mahalanobis":
+            metric_fct = MahalanobisDistance(self.aggregation_func)
         else:
             raise ValueError(f"Metric {metric} not recognized.")
         self.metric_fct = metric_fct
 
         if layer_key and obsm_key:
             raise ValueError(
-                "Cannot use 'counts_layer_key' and 'obsm_key' at the same time.\n"
-                "Please provide only one of the two keys."
+                "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
             )
         if not layer_key and not obsm_key:
             obsm_key = "X_pca"
@@ -195,37 +214,80 @@ class Distance:
             >>> D = Distance(X, Y)
         """
         if issparse(X):
-            X = X.A
+            X = X.toarray()
         if issparse(Y):
-            Y = Y.A
+            Y = Y.toarray()
 
         if len(X) == 0 or len(Y) == 0:
             raise ValueError("Neither X nor Y can be empty.")
 
         return self.metric_fct(X, Y, **kwargs)
 
+    def bootstrap(
+        self,
+        X: np.ndarray,
+        Y: np.ndarray,
+        *,
+        n_bootstrap: int = 100,
+        random_state: int = 0,
+        **kwargs,
+    ) -> MeanVar:
+        """Bootstrap computation of mean and variance of the distance between vectors X and Y.
+
+        Args:
+            X: First vector of shape (n_samples, n_features).
+            Y: Second vector of shape (n_samples, n_features).
+            n_bootstrap: Number of bootstrap samples.
+            random_state: Random state for bootstrapping.
+
+        Returns:
+            MeanVar: Mean and variance of distance between X and Y.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.distance_example()
+            >>> Distance = pt.tools.Distance(metric="edistance")
+            >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
+            >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
+            >>> D = Distance.bootstrap(X, Y)
+        """
+        return self._bootstrap_mode(
+            X,
+            Y,
+            n_bootstraps=n_bootstrap,
+            random_state=random_state,
+            **kwargs,
+        )
+
     def pairwise(
         self,
         adata: AnnData,
         groupby: str,
         groups: list[str] | None = None,
+        bootstrap: bool = False,
+        n_bootstrap: int = 100,
+        random_state: int = 0,
         show_progressbar: bool = True,
         n_jobs: int = -1,
         **kwargs,
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
         """Get pairwise distances between groups of cells.
 
         Args:
             adata: Annotated data matrix.
             groupby: Column name in adata.obs.
             groups: List of groups to compute pairwise distances for.
-                If None, uses all groups. Defaults to None.
-            show_progressbar: Whether to show progress bar. Defaults to True.
+                If None, uses all groups.
+            bootstrap: Whether to bootstrap the distance.
+            n_bootstrap: Number of bootstrap samples.
+            random_state: Random state for bootstrapping.
+            show_progressbar: Whether to show progress bar.
             n_jobs: Number of cores to use. Defaults to -1 (all).
             kwargs: Additional keyword arguments passed to the metric function.
 
         Returns:
             pd.DataFrame: Dataframe with pairwise distances.
+            tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
 
         Examples:
             >>> import pertpy as pt
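Because `bootstrap` returns the new `MeanVar` named tuple, its result can also be unpacked directly; a short continuation of the docstring example in the hunk above:

>>> mean_dist, var_dist = Distance.bootstrap(X, Y, n_bootstrap=100)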
@@ -236,6 +298,8 @@ class Distance:
         groups = adata.obs[groupby].unique() if groups is None else groups
         grouping = adata.obs[groupby].copy()
         df = pd.DataFrame(index=groups, columns=groups, dtype=float)
+        if bootstrap:
+            df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
         fct = track if show_progressbar else lambda iterable: iterable
 
         # Some metrics are able to handle precomputed distances. This means that
@@ -251,16 +315,29 @@ class Distance:
             for index_x, group_x in enumerate(fct(groups)):
                 idx_x = grouping == group_x
                 for group_y in groups[index_x:]:  # type: ignore
-                    if group_x == group_y:
-                        dist = 0.0  # by distance axiom
+                    # subset the pairwise distance matrix to the two groups
+                    idx_y = grouping == group_y
+                    sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
+                    sub_idx = grouping[idx_x | idx_y] == group_x
+                    if not bootstrap:
+                        if group_x == group_y:
+                            dist = 0.0
+                        else:
+                            dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
+                        df.loc[group_x, group_y] = dist
+                        df.loc[group_y, group_x] = dist
+
                     else:
-                        idx_y = grouping == group_y
-                        # subset the pairwise distance matrix to the two groups
-                        sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
-                        sub_idx = grouping[idx_x | idx_y] == group_x
-                        dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
-                    df.loc[group_x, group_y] = dist
-                    df.loc[group_y, group_x] = dist
+                        bootstrap_output = self._bootstrap_mode_precomputed(
+                            sub_pwd,
+                            sub_idx,
+                            n_bootstraps=n_bootstrap,
+                            random_state=random_state,
+                            **kwargs,
+                        )
+                        # In the bootstrap case, distance of group to itself is a mean and can be non-zero
+                        df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
+                        df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
         else:
             if self.layer_key:
                 embedding = adata.layers[self.layer_key]
@@ -269,18 +346,39 @@ class Distance:
             for index_x, group_x in enumerate(fct(groups)):
                 cells_x = embedding[grouping == group_x].copy()
                 for group_y in groups[index_x:]:  # type: ignore
-                    if group_x == group_y:
-                        dist = 0.0
+                    cells_y = embedding[grouping == group_y].copy()
+                    if not bootstrap:
+                        # By distance axiom, the distance between a group and itself is 0
+                        dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
+
+                        df.loc[group_x, group_y] = dist
+                        df.loc[group_y, group_x] = dist
                     else:
-                        cells_y = embedding[grouping == group_y].copy()
-                        dist = self(cells_x, cells_y, **kwargs)
-                    df.loc[group_x, group_y] = dist
-                    df.loc[group_y, group_x] = dist
+                        bootstrap_output = self.bootstrap(
+                            cells_x,
+                            cells_y,
+                            n_bootstrap=n_bootstrap,
+                            random_state=random_state,
+                            **kwargs,
+                        )
+                        # In the bootstrap case, distance of group to itself is a mean and can be non-zero
+                        df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
+                        df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
+
         df.index.name = groupby
         df.columns.name = groupby
         df.name = f"pairwise {self.metric}"
 
-        return df
+        if not bootstrap:
+            return df
+        else:
+            df = df.fillna(0)
+            df_var.index.name = groupby
+            df_var.columns.name = groupby
+            df_var = df_var.fillna(0)
+            df_var.name = f"pairwise {self.metric} variance"
+
+            return df, df_var
 
     def onesided_distances(
         self,
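With `bootstrap=True`, `pairwise` now returns two DataFrames (means and variances) instead of one, as shown in the hunk above. A minimal sketch following the file's docstring examples:

>>> import pertpy as pt
>>> adata = pt.dt.distance_example()
>>> distance = pt.tools.Distance(metric="edistance")
>>> means, variances = distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=100)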
@@ -288,10 +386,13 @@ class Distance:
         groupby: str,
         selected_group: str | None = None,
         groups: list[str] | None = None,
+        bootstrap: bool = False,
+        n_bootstrap: int = 100,
+        random_state: int = 0,
         show_progressbar: bool = True,
         n_jobs: int = -1,
         **kwargs,
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
         """Get distances between one selected cell group and the remaining other cell groups.
 
         Args:
@@ -299,13 +400,18 @@ class Distance:
             groupby: Column name in adata.obs.
             selected_group: Group to compute pairwise distances to all other.
             groups: List of groups to compute distances to selected_group for.
-                If None, uses all groups. Defaults to None.
-            show_progressbar: Whether to show progress bar. Defaults to True.
+                If None, uses all groups.
+            bootstrap: Whether to bootstrap the distance.
+            n_bootstrap: Number of bootstrap samples.
+            random_state: Random state for bootstrapping.
+            show_progressbar: Whether to show progress bar.
             n_jobs: Number of cores to use. Defaults to -1 (all).
             kwargs: Additional keyword arguments passed to the metric function.
 
         Returns:
             pd.DataFrame: Dataframe with distances of groups to selected_group.
+            tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
+
 
         Examples:
             >>> import pertpy as pt
@@ -314,20 +420,30 @@ class Distance:
             >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
         """
         if self.metric == "classifier_cp":
+            if bootstrap:
+                raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.")
             return self.metric_fct.onesided_distances(  # type: ignore
-                adata, groupby, selected_group, groups, show_progressbar, n_jobs, **kwargs
+                adata,
+                groupby,
+                selected_group,
+                groups,
+                show_progressbar,
+                n_jobs,
+                **kwargs,
             )
 
         groups = adata.obs[groupby].unique() if groups is None else groups
         grouping = adata.obs[groupby].copy()
         df = pd.Series(index=groups, dtype=float)
+        if bootstrap:
+            df_var = pd.Series(index=groups, dtype=float)
         fct = track if show_progressbar else lambda iterable: iterable
 
         # Some metrics are able to handle precomputed distances. This means that
         # the pairwise distances between all cells are computed once and then
         # passed to the metric function. This is much faster than computing the
         # pairwise distances for each group separately. Other metrics are not
-        # able to handle precomputed distances such as the PsuedobulkDistance.
+        # able to handle precomputed distances such as the PseudobulkDistance.
         if self.metric_fct.accepts_precomputed:
             # Precompute the pairwise distances if needed
             if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
@@ -337,14 +453,25 @@ class Distance:
                 idx_x = grouping == group_x
                 group_y = selected_group
                 if group_x == group_y:
-                    dist = 0.0  # by distance axiom
+                    df.loc[group_x] = 0.0  # by distance axiom
                 else:
                     idx_y = grouping == group_y
                     # subset the pairwise distance matrix to the two groups
                     sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
                     sub_idx = grouping[idx_x | idx_y] == group_x
-                    dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
-                df.loc[group_x] = dist
+                    if not bootstrap:
+                        dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
+                        df.loc[group_x] = dist
+                    else:
+                        bootstrap_output = self._bootstrap_mode_precomputed(
+                            sub_pwd,
+                            sub_idx,
+                            n_bootstraps=n_bootstrap,
+                            random_state=random_state,
+                            **kwargs,
+                        )
+                        df.loc[group_x] = bootstrap_output.mean
+                        df_var.loc[group_x] = bootstrap_output.variance
         else:
             if self.layer_key:
                 embedding = adata.layers[self.layer_key]
@@ -353,15 +480,32 @@ class Distance:
             for group_x in fct(groups):
                 cells_x = embedding[grouping == group_x].copy()
                 group_y = selected_group
-                if group_x == group_y:
-                    dist = 0.0
+                cells_y = embedding[grouping == group_y].copy()
+                if not bootstrap:
+                    # By distance axiom, the distance between a group and itself is 0
+                    dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
+                    df.loc[group_x] = dist
                 else:
-                    cells_y = embedding[grouping == group_y].copy()
-                    dist = self(cells_x, cells_y, **kwargs)
-                df.loc[group_x] = dist
+                    bootstrap_output = self.bootstrap(
+                        cells_x,
+                        cells_y,
+                        n_bootstrap=n_bootstrap,
+                        random_state=random_state,
+                        **kwargs,
+                    )
+                    # In the bootstrap case, distance of group to itself is a mean and can be non-zero
+                    df.loc[group_x] = bootstrap_output.mean
+                    df_var.loc[group_x] = bootstrap_output.variance
         df.index.name = groupby
         df.name = f"{self.metric} to {selected_group}"
-        return df
+        if not bootstrap:
+            return df
+        else:
+            df_var.index.name = groupby
+            df_var = df_var.fillna(0)
+            df_var.name = f"pairwise {self.metric} variance to {selected_group}"
+
+            return df, df_var
 
     def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
         """Precompute pairwise distances between all cells, writes to adata.obsp.
@@ -387,6 +531,77 @@ class Distance:
         pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
         adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
 
+    def compare_distance(
+        self,
+        pert: np.ndarray,
+        pred: np.ndarray,
+        ctrl: np.ndarray,
+        mode: Literal["simple", "scaled"] = "simple",
+        fit_to_pert_and_ctrl: bool = False,
+        **kwargs,
+    ) -> float:
+        """Compute the score of simulating a perturbation.
+
+        Args:
+            pert: Real perturbed data.
+            pred: Simulated perturbed data.
+            ctrl: Control data
+            mode: Mode to use.
+            fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
+            kwargs: Additional keyword arguments passed to the metric function.
+        """
+        if mode == "simple":
+            pass  # nothing to be done
+        elif mode == "scaled":
+            from sklearn.preprocessing import MinMaxScaler
+
+            scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl)
+            pred = scaler.transform(pred)
+            pert = scaler.transform(pert)
+        else:
+            raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")
+
+        d1 = self.metric_fct(pert, pred, **kwargs)
+        d2 = self.metric_fct(ctrl, pred, **kwargs)
+        return d1 / d2
+
+    def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
+        rng = np.random.default_rng(random_state)
+
+        distances = []
+        for _ in range(n_bootstraps):
+            X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)]
+            Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)]
+
+            distance = self(X_bootstrapped, Y_bootstrapped, **kwargs)
+            distances.append(distance)
+
+        mean = np.mean(distances)
+        variance = np.var(distances)
+        return MeanVar(mean=mean, variance=variance)
+
+    def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
+        rng = np.random.default_rng(random_state)
+
+        distances = []
+        for _ in range(n_bootstraps):
+            # To maintain the number of cells for both groups (whatever balancing they may have),
+            # we sample the positive and negative indices separately
+            bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True)
+            bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True)
+            bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx])
+            bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx)
+
+            bootstrap_sub_idx = sub_idx[bootstrap_idx]
+            bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs]
+
+            distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs)
+            distances.append(distance)
+
+        mean = np.mean(distances)
+        variance = np.var(distances)
+        return MeanVar(mean=mean, variance=variance)
+
 
 class AbstractDistance(ABC):
     """Abstract class of distance metrics between two sets of vectors."""
@@ -500,12 +715,17 @@ class WassersteinDistance(AbstractDistance):
 class EuclideanDistance(AbstractDistance):
     """Euclidean distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs)
+        return np.linalg.norm(
+            self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
+            ord=2,
+            **kwargs,
+        )
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.")
@@ -514,12 +734,21 @@ class EuclideanDistance(AbstractDistance):
 class MeanSquaredDistance(AbstractDistance):
     """Mean squared distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs) ** 2 / X.shape[1]
+        return (
+            np.linalg.norm(
+                self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
+                ord=2,
+                **kwargs,
+            )
+            ** 2
+            / X.shape[1]
+        )
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.")
@@ -528,12 +757,20 @@ class MeanSquaredDistance(AbstractDistance):
 class MeanAbsoluteDistance(AbstractDistance):
     """Absolute (Norm-1) distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=1, **kwargs) / X.shape[1]
+        return (
+            np.linalg.norm(
+                self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0),
+                ord=1,
+                **kwargs,
+            )
+            / X.shape[1]
+        )
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.")
@@ -558,12 +795,13 @@ class MeanPairwiseDistance(AbstractDistance):
 class PearsonDistance(AbstractDistance):
     """Pearson distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return 1 - pearsonr(X.mean(axis=0), Y.mean(axis=0))[0]
+        return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.")
572
810
  class SpearmanDistance(AbstractDistance):
573
811
  """Spearman distance between pseudobulk vectors."""
574
812
 
575
- def __init__(self) -> None:
813
+ def __init__(self, aggregation_func: Callable = np.mean) -> None:
576
814
  super().__init__()
577
815
  self.accepts_precomputed = False
816
+ self.aggregation_func = aggregation_func
578
817
 
579
818
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
580
- return 1 - spearmanr(X.mean(axis=0), Y.mean(axis=0))[0]
819
+ return 1 - spearmanr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0]
581
820
 
582
821
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
583
822
  raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.")
@@ -586,12 +825,13 @@ class SpearmanDistance(AbstractDistance):
 class KendallTauDistance(AbstractDistance):
     """Kendall-tau distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        x, y = X.mean(axis=0), Y.mean(axis=0)
+        x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)
         n = len(x)
         tau_corr = kendalltau(x, y).statistic
         tau_dist = (1 - tau_corr) * n * (n - 1) / 4
@@ -604,12 +844,13 @@ class KendallTauDistance(AbstractDistance):
 class CosineDistance(AbstractDistance):
     """Cosine distance between pseudobulk vectors."""
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return cosine(X.mean(axis=0), Y.mean(axis=0))
+        return cosine(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.")
@@ -620,12 +861,13 @@ class R2ScoreDistance(AbstractDistance):
 
     # NOTE: This is not a distance metric but a similarity metric.
 
-    def __init__(self) -> None:
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
         super().__init__()
         self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return 1 - r2_score(X.mean(axis=0), Y.mean(axis=0))
+        return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.")
@@ -834,6 +1076,7 @@ class ClassifierClassProjection(AbstractDistance):
         Similar to the parent function, the returned dataframe contains only the specified groups.
         """
         groups = adata.obs[groupby].unique() if groups is None else groups
+        fct = track if show_progressbar else lambda iterable: iterable
 
         X = adata[adata.obs[groupby] != selected_group].X
         labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values
@@ -844,7 +1087,8 @@ class ClassifierClassProjection(AbstractDistance):
         test_probas = reg.predict_proba(Y)
 
         df = pd.Series(index=groups, dtype=float)
-        for group in groups:
+
+        for group in fct(groups):
             if group == selected_group:
                 df.loc[group] = 0
             else:
@@ -857,3 +1101,95 @@ class ClassifierClassProjection(AbstractDistance):
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.")
+
+
+class MeanVarDistributionDistance(AbstractDistance):
+    """Distance between mean-var distributions of gene expression."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        """Difference of mean-var distributions in 2 matrices.
+
+        Args:
+            X: Normalized and log transformed cells x genes count matrix.
+            Y: Normalized and log transformed cells x genes count matrix.
+        """
+
+        def _mean_var(x, log: bool = False):
+            mean = np.mean(x, axis=0)
+            var = np.var(x, axis=0)
+            positive = mean > 0
+            mean = mean[positive]
+            var = var[positive]
+            if log:
+                mean = np.log(mean)
+                var = np.log(var)
+            return mean, var
+
+        def _prep_kde_data(x, y):
+            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
+
+        def _grid_points(d, n_points=100):
+            # Make grid, add 1 bin on lower/upper end to get final n_points
+            d_min = d.min()
+            d_max = d.max()
+            # Compute bin size
+            d_bin = (d_max - d_min) / (n_points - 2)
+            d_min = d_min - d_bin
+            d_max = d_max + d_bin
+            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
+
+        def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
+            # the thread_count is determined using the factor 0.875 as recommended here:
+            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
+            with multiprocessing.Pool(thread_count) as p:
+                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
+
+        def _kde_eval(d, grid):
+            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
+            # can not be compared well on regions further away from the data as they are -inf
+            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
+            return _parallel_score_samples(kde, grid)
+
+        mean_x, var_x = _mean_var(X, log=True)
+        mean_y, var_y = _mean_var(Y, log=True)
+
+        x = _prep_kde_data(mean_x, var_x)
+        y = _prep_kde_data(mean_y, var_y)
+
+        # Gridpoints to eval KDE on
+        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
+        var_grid = _grid_points(np.concatenate([var_x, var_y]))
+        grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
+
+        kde_x = _kde_eval(x, grid)
+        kde_y = _kde_eval(y, grid)
+
+        kde_diff = ((kde_x - kde_y) ** 2).mean()
+
+        return kde_diff
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
+
+
+class MahalanobisDistance(AbstractDistance):
+    """Mahalanobis distance between pseudobulk vectors."""
+
+    def __init__(self, aggregation_func: Callable = np.mean) -> None:
+        super().__init__()
+        self.accepts_precomputed = False
+        self.aggregation_func = aggregation_func
+
+    def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        return mahalanobis(
+            self.aggregation_func(X, axis=0),
+            self.aggregation_func(Y, axis=0),
+            np.linalg.inv(np.cov(X.T)),
+        )
+
+    def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
+        raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.")