pertpy-0.6.0-py3-none-any.whl → pertpy-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_milo.py CHANGED
@@ -3,21 +3,24 @@ from __future__ import annotations
3
3
  import logging
4
4
  import random
5
5
  import re
6
- from typing import Literal
6
+ from typing import TYPE_CHECKING, Literal
7
7
 
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
11
+ import scanpy as sc
12
+ import seaborn as sns
10
13
  from anndata import AnnData
14
+ from lamin_utils import logger
11
15
  from mudata import MuData
12
- from rich import print
13
-
14
- try:
15
- from rpy2.robjects import conversion, numpy2ri, pandas2ri
16
- from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
17
- except ModuleNotFoundError:
18
- print(
19
- "[bold yellow]ryp2 is not installed. Install with [green]pip install rpy2 [yellow]to run tools with R support."
20
- )
16
+
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Sequence
19
+
20
+ from matplotlib.axes import Axes
21
+ from matplotlib.colors import Colormap
22
+ from matplotlib.figure import Figure
23
+
21
24
  from scipy.sparse import csr_matrix
22
25
  from sklearn.metrics.pairwise import euclidean_distances
23
26
 
@@ -26,7 +29,16 @@ class Milo:
26
29
  """Python implementation of Milo."""
27
30
 
28
31
  def __init__(self):
29
- pass
32
+ try:
33
+ from rpy2.robjects import conversion, numpy2ri, pandas2ri
34
+ from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
35
+ except ModuleNotFoundError:
36
+ raise ImportError("milo requires rpy2 to be installed.") from None
37
+
38
+ try:
39
+ importr("edgeR")
40
+ except ImportError as e:
41
+ raise ImportError("milo requires a valid R installation with edger installed:\n") from e
30
42
 
31
43
  def load(
32
44
  self,
@@ -39,7 +51,7 @@ class Milo:
39
51
  input: AnnData
40
52
  feature_key: Key to store the cell-level AnnData object in the MuData object
41
53
  Returns:
42
- MuData: MuData object with original AnnData (default is `mudata[feature_key]`).
54
+ MuData: MuData object with original AnnData.
43
55
 
44
56
  Examples:
45
57
  >>> import pertpy as pt
@@ -71,11 +83,10 @@ class Milo:
71
83
  neighbors_key: The key in `adata.obsp` or `mdata[feature_key].obsp` to use as KNN graph.
72
84
  If not specified, `make_nhoods` looks at .obsp[‘connectivities’] for connectivities (default storage places for `scanpy.pp.neighbors`).
73
85
  If specified, it looks at .obsp[.uns[neighbors_key][‘connectivities_key’]] for connectivities.
74
- (default: None)
75
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
76
- prop: Fraction of cells to sample for neighbourhood index search. (default: 0.1)
77
- seed: Random seed for cell sampling. (default: 0)
78
- copy: Determines whether a copy of the `adata` is returned. (default: False)
86
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
87
+ prop: Fraction of cells to sample for neighbourhood index search.
88
+ seed: Random seed for cell sampling.
89
+ copy: Determines whether a copy of the `adata` is returned.
79
90
 
80
91
  Returns:
81
92
  If `copy=True`, returns the copy of `adata` with the result in `.obs`, `.obsm`, and `.uns`.
@@ -119,7 +130,7 @@ class Milo:
119
130
  try:
120
131
  knn_graph = adata.obsp["connectivities"].copy()
121
132
  except KeyError:
122
- print('No "connectivities" slot in adata.obsp -- please run scanpy.pp.neighbors(adata) first')
133
+ logger.error('No "connectivities" slot in adata.obsp -- please run scanpy.pp.neighbors(adata) first')
123
134
  raise
124
135
  else:
125
136
  try:
@@ -174,6 +185,7 @@ class Milo:
174
185
  dist_mat = knn_dists[nhood_ixs, :]
175
186
  k_distances = dist_mat.max(1).toarray().ravel()
176
187
  adata.obs["nhood_kth_distance"] = 0
188
+ adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
177
189
  adata.obs.loc[adata.obs["nhood_ixs_refined"] == 1, "nhood_kth_distance"] = k_distances
178
190
 
179
191
  if copy:
@@ -190,7 +202,7 @@ class Milo:
190
202
  Args:
191
203
  data: AnnData object with neighbourhoods defined in `obsm['nhoods']` or MuData object with a modality with neighbourhoods defined in `obsm['nhoods']`
192
204
  sample_col: Column in adata.obs that contains sample information
193
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
205
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
194
206
 
195
207
  Returns:
196
208
  MuData object storing the original (i.e. rna) AnnData in `mudata[feature_key]`
@@ -221,7 +233,7 @@ class Milo:
221
233
  try:
222
234
  nhoods = adata.obsm["nhoods"]
223
235
  except KeyError:
224
- print('Cannot find "nhoods" slot in adata.obsm -- please run milopy.make_nhoods(adata)')
236
+ logger.error('Cannot find "nhoods" slot in adata.obsm -- please run milopy.make_nhoods(adata)')
225
237
  raise
226
238
  # Make nhood abundance matrix
227
239
  sample_dummies = pd.get_dummies(adata.obs[sample_col])
@@ -229,7 +241,7 @@ class Milo:
229
241
  sample_dummies = csr_matrix(sample_dummies.values)
230
242
  nhood_count_mat = nhoods.T.dot(sample_dummies)
231
243
  sample_obs = pd.DataFrame(index=all_samples)
232
- sample_adata = AnnData(X=nhood_count_mat.T, obs=sample_obs, dtype=np.float32)
244
+ sample_adata = AnnData(X=nhood_count_mat.T, obs=sample_obs)
233
245
  sample_adata.uns["sample_col"] = sample_col
234
246
  # Save nhood index info
235
247
  sample_adata.var["index_cell"] = adata.obs_names[adata.obs["nhood_ixs_refined"] == 1]
@@ -261,10 +273,10 @@ class Milo:
261
273
  design: Formula for the test, following glm syntax from R (e.g. '~ condition').
262
274
  Terms should be columns in `milo_mdata[feature_key].obs`.
263
275
  model_contrasts: A string vector that defines the contrasts used to perform DA testing, following glm syntax from R (e.g. "conditionDisease - conditionControl").
264
- If no contrast is specified (default), then the last categorical level in condition of interest is used as the test group. Defaults to None.
265
- subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test. Defaults to None.
266
- add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default. Defaults to True.
267
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
276
+ If no contrast is specified (default), then the last categorical level in condition of interest is used as the test group.
277
+ subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test.
278
+ add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default.
279
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
268
280
  solver: The solver to fit the model to. One of "edger" (requires R, rpy2 and edgeR to be installed) or "batchglm"
269
281
 
270
282
  Returns:
@@ -288,8 +300,8 @@ class Milo:
288
300
  try:
289
301
  sample_adata = mdata["milo"]
290
302
  except KeyError:
291
- print(
292
- "[bold red]milo_mdata should be a MuData object with two slots:"
303
+ logger.error(
304
+ "milo_mdata should be a MuData object with two slots:"
293
305
  " feature_key and 'milo' - please run milopy.count_nhoods() first"
294
306
  )
295
307
  raise
@@ -303,7 +315,7 @@ class Milo:
303
315
  sample_obs = adata.obs[covariates + [sample_col]].drop_duplicates()
304
316
  except KeyError:
305
317
  missing_cov = [x for x in covariates if x not in sample_adata.obs.columns]
306
- print("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
318
+ logger.warning("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
307
319
  raise
308
320
  sample_obs = sample_obs[covariates + [sample_col]]
309
321
  sample_obs.index = sample_obs[sample_col].astype("str")
@@ -311,7 +323,7 @@ class Milo:
311
323
  try:
312
324
  assert sample_obs.loc[sample_adata.obs_names].shape[0] == len(sample_adata.obs_names)
313
325
  except AssertionError:
314
- print(
326
+ logger.warning(
315
327
  f"Values in mdata[{feature_key}].obs[{covariates}] cannot be unambiguously assigned to each sample"
316
328
  f" -- each sample value should match a single covariate value"
317
329
  )
@@ -323,7 +335,9 @@ class Milo:
323
335
  design_df = sample_adata.obs[covariates]
324
336
  except KeyError:
325
337
  missing_cov = [x for x in covariates if x not in sample_adata.obs.columns]
326
- print('Covariates {c} are not columns in adata.uns["sample_adata"].obs'.format(c=" ".join(missing_cov)))
338
+ logger.error(
339
+ 'Covariates {c} are not columns in adata.uns["sample_adata"].obs'.format(c=" ".join(missing_cov))
340
+ )
327
341
  raise
328
342
  # Get count matrix
329
343
  count_mat = sample_adata.X.T.toarray()
@@ -367,6 +381,8 @@ class Milo:
367
381
  return(colnames(m))
368
382
  }
369
383
  """
384
+ from rpy2.robjects.packages import STAP
385
+
370
386
  get_model_cols = STAP(r_str, "get_model_cols")
371
387
  model_mat_cols = get_model_cols.get_model_cols(design_df, design)
372
388
  model_df = pd.DataFrame(model)
@@ -374,13 +390,16 @@ class Milo:
374
390
  try:
375
391
  mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
376
392
  except ValueError:
377
- print("Model contrasts must be in the form 'A-B' or 'A+B'")
393
+ logger.error("Model contrasts must be in the form 'A-B' or 'A+B'")
378
394
  raise
379
395
  res = base.as_data_frame(
380
396
  edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
381
397
  )
382
398
  else:
383
399
  res = base.as_data_frame(edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf))
400
+
401
+ from rpy2.robjects import conversion
402
+
384
403
  res = conversion.rpy2py(res)
385
404
  if not isinstance(res, pd.DataFrame):
386
405
  res = pd.DataFrame(res)
@@ -405,7 +424,7 @@ class Milo:
405
424
  Args:
406
425
  mdata: MuData object
407
426
  anno_col: Column in adata.obs containing the cell annotations to use for nhood labelling
408
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
427
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
409
428
 
410
429
  Returns:
411
430
  None. Adds in place:
@@ -423,12 +442,12 @@ class Milo:
423
442
  >>> sc.pp.neighbors(mdata["rna"])
424
443
  >>> milo.make_nhoods(mdata["rna"])
425
444
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
426
- >>> milo.annotate_nhoods(mdata, anno_col='cell_type')
445
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
427
446
  """
428
447
  try:
429
448
  sample_adata = mdata["milo"]
430
449
  except KeyError:
431
- print(
450
+ logger.error(
432
451
  "milo_mdata should be a MuData object with two slots: feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
433
452
  )
434
453
  raise
@@ -459,7 +478,7 @@ class Milo:
459
478
  Args:
460
479
  mdata: MuData object
461
480
  anno_col: Column in adata.obs containing the cell annotations to use for nhood labelling
462
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
481
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
463
482
 
464
483
  Returns:
465
484
  None. Adds in place:
@@ -474,7 +493,7 @@ class Milo:
474
493
  >>> sc.pp.neighbors(mdata["rna"])
475
494
  >>> milo.make_nhoods(mdata["rna"])
476
495
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
477
- >>> milo.annotate_nhoods_continuous(mdata, anno_col='nUMI')
496
+ >>> milo.annotate_nhoods_continuous(mdata, anno_col="nUMI")
478
497
  """
479
498
  if "milo" not in mdata.mod:
480
499
  raise ValueError(
@@ -500,7 +519,7 @@ class Milo:
500
519
  Args:
501
520
  mdata: MuData object
502
521
  new_covariates: columns in `milo_mdata[feature_key].obs` to add to `milo_mdata['milo'].obs`.
503
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
522
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
504
523
 
505
524
  Returns:
506
525
  None, adds columns to `milo_mdata['milo']` in place
@@ -519,7 +538,7 @@ class Milo:
519
538
  try:
520
539
  sample_adata = mdata["milo"]
521
540
  except KeyError:
522
- print(
541
+ logger.error(
523
542
  "milo_mdata should be a MuData object with two slots: feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
524
543
  )
525
544
  raise
@@ -533,14 +552,14 @@ class Milo:
533
552
  sample_obs = adata.obs[covariates + [sample_col]].drop_duplicates()
534
553
  except KeyError:
535
554
  missing_cov = [covar for covar in covariates if covar not in sample_adata.obs.columns]
536
- print("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
555
+ logger.error("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
537
556
  raise
538
557
  sample_obs = sample_obs[covariates + [sample_col]].astype("str")
539
558
  sample_obs.index = sample_obs[sample_col]
540
559
  try:
541
560
  assert sample_obs.loc[sample_adata.obs_names].shape[0] == len(sample_adata.obs_names)
542
561
  except ValueError:
543
- print(
562
+ logger.error(
544
563
  "Covariates cannot be unambiguously assigned to each sample -- each sample value should match a single covariate value"
545
564
  )
546
565
  raise
@@ -551,8 +570,8 @@ class Milo:
551
570
 
552
571
  Args:
553
572
  mdata: MuData object
554
- basis: Name of the obsm basis to use for layout of neighbourhoods (key in `adata.obsm`). Defaults to "X_umap".
555
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
573
+ basis: Name of the obsm basis to use for layout of neighbourhoods (key in `adata.obsm`).
574
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
556
575
 
557
576
  Returns:
558
577
  - `milo_mdata['milo'].varp['nhood_connectivities']`: graph of overlap between neighbourhoods (i.e. no of shared cells)
@@ -584,13 +603,13 @@ class Milo:
584
603
  "distances_key": "",
585
604
  }
586
605
 
587
- def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_key: str | None = "rna"):
606
+ def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_key: str | None = "rna") -> None:
588
607
  """Calculates the mean expression in neighbourhoods of each feature.
589
608
 
590
609
  Args:
591
610
  mdata: MuData object
592
- layer: If provided, use `milo_mdata[feature_key][layer]` as expression matrix instead of `milo_mdata[feature_key].X`. Defaults to None.
593
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
611
+ layer: If provided, use `milo_mdata[feature_key][layer]` as expression matrix instead of `milo_mdata[feature_key].X`.
612
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
594
613
 
595
614
  Returns:
596
615
  Updates adata in place to store the matrix of average expression in each neighbourhood in `milo_mdata['milo'].varm['expr']`
@@ -609,7 +628,7 @@ class Milo:
609
628
  try:
610
629
  sample_adata = mdata["milo"]
611
630
  except KeyError:
612
- print(
631
+ logger.error(
613
632
  "milo_mdata should be a MuData object with two slots:"
614
633
  " feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
615
634
  )
@@ -633,6 +652,9 @@ class Milo:
633
652
  self,
634
653
  ):
635
654
  """Set up rpy2 to run edgeR"""
655
+ from rpy2.robjects import numpy2ri, pandas2ri
656
+ from rpy2.robjects.packages import importr
657
+
636
658
  numpy2ri.activate()
637
659
  pandas2ri.activate()
638
660
  edgeR = self._try_import_bioc_library("edgeR")
@@ -651,11 +673,13 @@ class Milo:
651
673
  Args:
652
674
  name (str): R packages name
653
675
  """
676
+ from rpy2.robjects.packages import PackageNotInstalledError, importr
677
+
654
678
  try:
655
679
  _r_lib = importr(name)
656
680
  return _r_lib
657
681
  except PackageNotInstalledError:
658
- print(f"Install Bioconductor library `{name!r}` first as `BiocManager::install({name!r}).`")
682
+ logger.error(f"Install Bioconductor library `{name!r}` first as `BiocManager::install({name!r}).`")
659
683
  raise
660
684
 
661
685
  def _graph_spatial_fdr(
@@ -663,11 +687,13 @@ class Milo:
663
687
  sample_adata: AnnData,
664
688
  neighbors_key: str | None = None,
665
689
  ):
666
- """FDR correction weighted on inverse of connectivity of neighbourhoods. The distance to the k-th nearest neighbor is used as a measure of connectivity.
690
+ """FDR correction weighted on inverse of connectivity of neighbourhoods.
691
+
692
+ The distance to the k-th nearest neighbor is used as a measure of connectivity.
667
693
 
668
694
  Args:
669
695
  sample_adata: Sample-level AnnData.
670
- neighbors_key: The key in `adata.obsp` to use as KNN graph. Defaults to None.
696
+ neighbors_key: The key in `adata.obsp` to use as KNN graph.
671
697
  """
672
698
  # use 1/connectivity as the weighting for the weighted BH adjustment from Cydar
673
699
  w = 1 / sample_adata.var["kth_distance"]
@@ -686,3 +712,334 @@ class Milo:
686
712
 
687
713
  sample_adata.var["SpatialFDR"] = np.nan
688
714
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
715
+
716
+ def plot_nhood_graph(
717
+ self,
718
+ mdata: MuData,
719
+ alpha: float = 0.1,
720
+ min_logFC: float = 0,
721
+ min_size: int = 10,
722
+ plot_edges: bool = False,
723
+ title: str = "DA log-Fold Change",
724
+ color_map: Colormap | str | None = None,
725
+ palette: str | Sequence[str] | None = None,
726
+ ax: Axes | None = None,
727
+ show: bool | None = None,
728
+ save: bool | str | None = None,
729
+ **kwargs,
730
+ ) -> None:
731
+ """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding)
732
+
733
+ Args:
734
+ mdata: MuData object
735
+ alpha: Significance threshold. (default: 0.1)
736
+ min_logFC: Minimum absolute log-Fold Change to show results. If it is 0, show all significant neighbourhoods.
737
+ min_size: Minimum size of nodes in visualization. (default: 10)
738
+ plot_edges: If edges for neighbourhood overlaps should be plotted.
739
+ title: Plot title.
740
+ show: Show the plot, do not return axis.
741
+ save: If `True` or a `str`, save the figure. A string is appended to the default filename.
742
+ Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
743
+ **kwargs: Additional arguments to `scanpy.pl.embedding`.
744
+
745
+ Examples:
746
+ >>> import pertpy as pt
747
+ >>> import scanpy as sc
748
+ >>> adata = pt.dt.bhattacherjee()
749
+ >>> milo = pt.tl.Milo()
750
+ >>> mdata = milo.load(adata)
751
+ >>> sc.pp.neighbors(mdata["rna"])
752
+ >>> sc.tl.umap(mdata["rna"])
753
+ >>> milo.make_nhoods(mdata["rna"])
754
+ >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
755
+ >>> milo.da_nhoods(mdata,
756
+ >>> design='~label',
757
+ >>> model_contrasts='labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine')
758
+ >>> milo.build_nhood_graph(mdata)
759
+ >>> milo.plot_nhood_graph(mdata)
760
+
761
+ Preview:
762
+ .. image:: /_static/docstring_previews/milo_nhood_graph.png
763
+ """
764
+ nhood_adata = mdata["milo"].T.copy()
765
+
766
+ if "Nhood_size" not in nhood_adata.obs.columns:
767
+ raise KeyError(
768
+ 'Cannot find "Nhood_size" column in adata.uns["nhood_adata"].obs -- \
769
+ please run milopy.utils.build_nhood_graph(adata)'
770
+ )
771
+
772
+ nhood_adata.obs["graph_color"] = nhood_adata.obs["logFC"]
773
+ nhood_adata.obs.loc[nhood_adata.obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
774
+ nhood_adata.obs["abs_logFC"] = abs(nhood_adata.obs["logFC"])
775
+ nhood_adata.obs.loc[nhood_adata.obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan
776
+
777
+ # Plotting order - extreme logFC on top
778
+ nhood_adata.obs.loc[nhood_adata.obs["graph_color"].isna(), "abs_logFC"] = np.nan
779
+ ordered = nhood_adata.obs.sort_values("abs_logFC", na_position="first").index
780
+ nhood_adata = nhood_adata[ordered]
781
+
782
+ vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
783
+ vmin = -vmax
784
+
785
+ sc.pl.embedding(
786
+ nhood_adata,
787
+ "X_milo_graph",
788
+ color="graph_color",
789
+ cmap="RdBu_r",
790
+ size=nhood_adata.obs["Nhood_size"] * min_size,
791
+ edges=plot_edges,
792
+ neighbors_key="nhood",
793
+ sort_order=False,
794
+ frameon=False,
795
+ vmax=vmax,
796
+ vmin=vmin,
797
+ title=title,
798
+ color_map=color_map,
799
+ palette=palette,
800
+ ax=ax,
801
+ show=show,
802
+ save=save,
803
+ **kwargs,
804
+ )
805
+
806
+ def plot_nhood(
807
+ self,
808
+ mdata: MuData,
809
+ ix: int,
810
+ feature_key: str | None = "rna",
811
+ basis: str = "X_umap",
812
+ color_map: Colormap | str | None = None,
813
+ palette: str | Sequence[str] | None = None,
814
+ return_fig: bool | None = None,
815
+ ax: Axes | None = None,
816
+ show: bool | None = None,
817
+ save: bool | str | None = None,
818
+ **kwargs,
819
+ ) -> None:
820
+ """Visualize cells in a neighbourhood.
821
+
822
+ Args:
823
+ mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
824
+ ix: index of neighbourhood to visualize
825
+ basis: Embedding to use for visualization.
826
+ show: Show the plot, do not return axis.
827
+ save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
828
+ **kwargs: Additional arguments to `scanpy.pl.embedding`.
829
+
830
+ Examples:
831
+ >>> import pertpy as pt
832
+ >>> import scanpy as sc
833
+ >>> adata = pt.dt.bhattacherjee()
834
+ >>> milo = pt.tl.Milo()
835
+ >>> mdata = milo.load(adata)
836
+ >>> sc.pp.neighbors(mdata["rna"])
837
+ >>> sc.tl.umap(mdata["rna"])
838
+ >>> milo.make_nhoods(mdata["rna"])
839
+ >>> milo.plot_nhood(mdata, ix=0)
840
+
841
+ Preview:
842
+ .. image:: /_static/docstring_previews/milo_nhood.png
843
+ """
844
+ mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
845
+ sc.pl.embedding(
846
+ mdata[feature_key],
847
+ basis,
848
+ color="Nhood",
849
+ size=30,
850
+ title="Nhood" + str(ix),
851
+ color_map=color_map,
852
+ palette=palette,
853
+ return_fig=return_fig,
854
+ ax=ax,
855
+ show=show,
856
+ save=save,
857
+ **kwargs,
858
+ )
859
+
860
+ def plot_da_beeswarm(
861
+ self,
862
+ mdata: MuData,
863
+ feature_key: str | None = "rna",
864
+ anno_col: str = "nhood_annotation",
865
+ alpha: float = 0.1,
866
+ subset_nhoods: list[str] = None,
867
+ palette: str | Sequence[str] | dict[str, str] | None = None,
868
+ return_fig: bool | None = None,
869
+ save: bool | str | None = None,
870
+ show: bool | None = None,
871
+ ) -> Figure | Axes | None:
872
+ """Plot beeswarm plot of logFC against nhood labels
873
+
874
+ Args:
875
+ mdata: MuData object
876
+ anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
877
+ alpha: Significance threshold. (default: 0.1)
878
+ subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
879
+ palette: Name of Seaborn color palette for violinplots.
880
+ Defaults to pre-defined category colors for violinplots.
881
+
882
+ Examples:
883
+ >>> import pertpy as pt
884
+ >>> import scanpy as sc
885
+ >>> adata = pt.dt.bhattacherjee()
886
+ >>> milo = pt.tl.Milo()
887
+ >>> mdata = milo.load(adata)
888
+ >>> sc.pp.neighbors(mdata["rna"])
889
+ >>> milo.make_nhoods(mdata["rna"])
890
+ >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
891
+ >>> milo.da_nhoods(mdata, design="~label")
892
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
893
+ >>> milo.plot_da_beeswarm(mdata)
894
+
895
+ Preview:
896
+ .. image:: /_static/docstring_previews/milo_da_beeswarm.png
897
+ """
898
+ try:
899
+ nhood_adata = mdata["milo"].T.copy()
900
+ except KeyError:
901
+ raise RuntimeError(
902
+ "mdata should be a MuData object with two slots: feature_key and 'milo'. Run 'milopy.count_nhoods(adata)' first."
903
+ ) from None
904
+
905
+ try:
906
+ nhood_adata.obs[anno_col]
907
+ except KeyError:
908
+ raise RuntimeError(
909
+ f"Unable to find {anno_col} in mdata['milo'].var. Run 'milopy.utils.annotate_nhoods(adata, anno_col)' first"
910
+ ) from None
911
+
912
+ if subset_nhoods is not None:
913
+ nhood_adata = nhood_adata[nhood_adata.obs[anno_col].isin(subset_nhoods)]
914
+
915
+ try:
916
+ nhood_adata.obs["logFC"]
917
+ except KeyError:
918
+ raise RuntimeError(
919
+ "Unable to find 'logFC' in mdata.uns['nhood_adata'].obs. Run 'core.da_nhoods(adata)' first."
920
+ ) from None
921
+
922
+ sorted_annos = (
923
+ nhood_adata.obs[[anno_col, "logFC"]].groupby(anno_col).median().sort_values("logFC", ascending=True).index
924
+ )
925
+
926
+ anno_df = nhood_adata.obs[[anno_col, "logFC", "SpatialFDR"]].copy()
927
+ anno_df["is_signif"] = anno_df["SpatialFDR"] < alpha
928
+ anno_df = anno_df[anno_df[anno_col] != "nan"]
929
+
930
+ try:
931
+ obs_col = nhood_adata.uns["annotation_obs"]
932
+ if palette is None:
933
+ palette = dict(
934
+ zip(
935
+ mdata[feature_key].obs[obs_col].cat.categories,
936
+ mdata[feature_key].uns[f"{obs_col}_colors"],
937
+ strict=False,
938
+ )
939
+ )
940
+ sns.violinplot(
941
+ data=anno_df,
942
+ y=anno_col,
943
+ x="logFC",
944
+ order=sorted_annos,
945
+ inner=None,
946
+ orient="h",
947
+ palette=palette,
948
+ linewidth=0,
949
+ scale="width",
950
+ )
951
+ except BaseException: # noqa: BLE001
952
+ sns.violinplot(
953
+ data=anno_df,
954
+ y=anno_col,
955
+ x="logFC",
956
+ order=sorted_annos,
957
+ inner=None,
958
+ orient="h",
959
+ linewidth=0,
960
+ scale="width",
961
+ )
962
+ sns.stripplot(
963
+ data=anno_df,
964
+ y=anno_col,
965
+ x="logFC",
966
+ order=sorted_annos,
967
+ size=2,
968
+ hue="is_signif",
969
+ palette=["grey", "black"],
970
+ orient="h",
971
+ alpha=0.5,
972
+ )
973
+ plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
974
+ plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
975
+
976
+ if save:
977
+ plt.savefig(save, bbox_inches="tight")
978
+ return None
979
+ if show:
980
+ plt.show()
981
+ return None
982
+ if return_fig:
983
+ return plt.gcf()
984
+ if (not show and not save) or (show is None and save is None):
985
+ return plt.gca()
986
+
987
+ return None
988
+
989
+ def plot_nhood_counts_by_cond(
990
+ self,
991
+ mdata: MuData,
992
+ test_var: str,
993
+ subset_nhoods: list[str] = None,
994
+ log_counts: bool = False,
995
+ return_fig: bool | None = None,
996
+ save: bool | str | None = None,
997
+ show: bool | None = None,
998
+ ) -> Figure | Axes | None:
999
+ """Plot boxplot of cell numbers vs condition of interest.
1000
+
1001
+ Args:
1002
+ mdata: MuData object storing cell level and nhood level information
1003
+ test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
1004
+ subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
1005
+ log_counts: Whether to plot log1p of cell counts.
1006
+ """
1007
+ try:
1008
+ nhood_adata = mdata["milo"].T.copy()
1009
+ except KeyError:
1010
+ raise RuntimeError(
1011
+ "mdata should be a MuData object with two slots: feature_key and 'milo'. Run milopy.count_nhoods(mdata) first"
1012
+ ) from None
1013
+
1014
+ if subset_nhoods is None:
1015
+ subset_nhoods = nhood_adata.obs_names
1016
+
1017
+ pl_df = pd.DataFrame(nhood_adata[subset_nhoods].X.A, columns=nhood_adata.var_names).melt(
1018
+ var_name=nhood_adata.uns["sample_col"], value_name="n_cells"
1019
+ )
1020
+ pl_df = pd.merge(pl_df, nhood_adata.var)
1021
+ pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
1022
+ if not log_counts:
1023
+ sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue")
1024
+ sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3)
1025
+ plt.ylabel("# cells")
1026
+ else:
1027
+ sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue")
1028
+ sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3)
1029
+ plt.ylabel("log(# cells + 1)")
1030
+
1031
+ plt.xticks(rotation=90)
1032
+ plt.xlabel(test_var)
1033
+
1034
+ if save:
1035
+ plt.savefig(save, bbox_inches="tight")
1036
+ return None
1037
+ if show:
1038
+ plt.show()
1039
+ return None
1040
+ if return_fig:
1041
+ return plt.gcf()
1042
+ if not (show or save):
1043
+ return plt.gca()
1044
+
1045
+ return None