pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_milo.py CHANGED
@@ -3,21 +3,24 @@ from __future__ import annotations
3
3
  import logging
4
4
  import random
5
5
  import re
6
- from typing import Literal
6
+ from typing import TYPE_CHECKING, Literal
7
7
 
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
11
+ import scanpy as sc
12
+ import seaborn as sns
10
13
  from anndata import AnnData
14
+ from lamin_utils import logger
11
15
  from mudata import MuData
12
- from rich import print
13
-
14
- try:
15
- from rpy2.robjects import conversion, numpy2ri, pandas2ri
16
- from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
17
- except ModuleNotFoundError:
18
- print(
19
- "[bold yellow]ryp2 is not installed. Install with [green]pip install rpy2 [yellow]to run tools with R support."
20
- )
16
+
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Sequence
19
+
20
+ from matplotlib.axes import Axes
21
+ from matplotlib.colors import Colormap
22
+ from matplotlib.figure import Figure
23
+
21
24
  from scipy.sparse import csr_matrix
22
25
  from sklearn.metrics.pairwise import euclidean_distances
23
26
 
@@ -26,7 +29,16 @@ class Milo:
26
29
  """Python implementation of Milo."""
27
30
 
28
31
  def __init__(self):
29
- pass
32
+ try:
33
+ from rpy2.robjects import conversion, numpy2ri, pandas2ri
34
+ from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
35
+ except ModuleNotFoundError:
36
+ raise ImportError("milo requires rpy2 to be installed.") from None
37
+
38
+ try:
39
+ importr("edgeR")
40
+ except ImportError as e:
41
+ raise ImportError("milo requires a valid R installation with edger installed:\n") from e
30
42
 
31
43
  def load(
32
44
  self,
@@ -39,7 +51,7 @@ class Milo:
39
51
  input: AnnData
40
52
  feature_key: Key to store the cell-level AnnData object in the MuData object
41
53
  Returns:
42
- MuData: MuData object with original AnnData (default is `mudata[feature_key]`).
54
+ MuData: MuData object with original AnnData.
43
55
 
44
56
  Examples:
45
57
  >>> import pertpy as pt
@@ -71,11 +83,10 @@ class Milo:
71
83
  neighbors_key: The key in `adata.obsp` or `mdata[feature_key].obsp` to use as KNN graph.
72
84
  If not specified, `make_nhoods` looks .obsp[‘connectivities’] for connectivities (default storage places for `scanpy.pp.neighbors`).
73
85
  If specified, it looks at .obsp[.uns[neighbors_key][‘connectivities_key’]] for connectivities.
74
- (default: None)
75
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
76
- prop: Fraction of cells to sample for neighbourhood index search. (default: 0.1)
77
- seed: Random seed for cell sampling. (default: 0)
78
- copy: Determines whether a copy of the `adata` is returned. (default: False)
86
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
87
+ prop: Fraction of cells to sample for neighbourhood index search.
88
+ seed: Random seed for cell sampling.
89
+ copy: Determines whether a copy of the `adata` is returned.
79
90
 
80
91
  Returns:
81
92
  If `copy=True`, returns the copy of `adata` with the result in `.obs`, `.obsm`, and `.uns`.
@@ -119,7 +130,7 @@ class Milo:
119
130
  try:
120
131
  knn_graph = adata.obsp["connectivities"].copy()
121
132
  except KeyError:
122
- print('No "connectivities" slot in adata.obsp -- please run scanpy.pp.neighbors(adata) first')
133
+ logger.error('No "connectivities" slot in adata.obsp -- please run scanpy.pp.neighbors(adata) first')
123
134
  raise
124
135
  else:
125
136
  try:
@@ -174,6 +185,7 @@ class Milo:
174
185
  dist_mat = knn_dists[nhood_ixs, :]
175
186
  k_distances = dist_mat.max(1).toarray().ravel()
176
187
  adata.obs["nhood_kth_distance"] = 0
188
+ adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
177
189
  adata.obs.loc[adata.obs["nhood_ixs_refined"] == 1, "nhood_kth_distance"] = k_distances
178
190
 
179
191
  if copy:
@@ -190,7 +202,7 @@ class Milo:
190
202
  Args:
191
203
  data: AnnData object with neighbourhoods defined in `obsm['nhoods']` or MuData object with a modality with neighbourhoods defined in `obsm['nhoods']`
192
204
  sample_col: Column in adata.obs that contains sample information
193
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
205
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
194
206
 
195
207
  Returns:
196
208
  MuData object storing the original (i.e. rna) AnnData in `mudata[feature_key]`
@@ -221,7 +233,7 @@ class Milo:
221
233
  try:
222
234
  nhoods = adata.obsm["nhoods"]
223
235
  except KeyError:
224
- print('Cannot find "nhoods" slot in adata.obsm -- please run milopy.make_nhoods(adata)')
236
+ logger.error('Cannot find "nhoods" slot in adata.obsm -- please run milopy.make_nhoods(adata)')
225
237
  raise
226
238
  # Make nhood abundance matrix
227
239
  sample_dummies = pd.get_dummies(adata.obs[sample_col])
@@ -229,7 +241,7 @@ class Milo:
229
241
  sample_dummies = csr_matrix(sample_dummies.values)
230
242
  nhood_count_mat = nhoods.T.dot(sample_dummies)
231
243
  sample_obs = pd.DataFrame(index=all_samples)
232
- sample_adata = AnnData(X=nhood_count_mat.T, obs=sample_obs, dtype=np.float32)
244
+ sample_adata = AnnData(X=nhood_count_mat.T, obs=sample_obs)
233
245
  sample_adata.uns["sample_col"] = sample_col
234
246
  # Save nhood index info
235
247
  sample_adata.var["index_cell"] = adata.obs_names[adata.obs["nhood_ixs_refined"] == 1]
@@ -261,10 +273,10 @@ class Milo:
261
273
  design: Formula for the test, following glm syntax from R (e.g. '~ condition').
262
274
  Terms should be columns in `milo_mdata[feature_key].obs`.
263
275
  model_contrasts: A string vector that defines the contrasts used to perform DA testing, following glm syntax from R (e.g. "conditionDisease - conditionControl").
264
- If no contrast is specified (default), then the last categorical level in condition of interest is used as the test group. Defaults to None.
265
- subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test. Defaults to None.
266
- add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default. Defaults to True.
267
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
276
+ If no contrast is specified (default), then the last categorical level in condition of interest is used as the test group.
277
+ subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test.
278
+ add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default.
279
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
268
280
  solver: The solver to fit the model to. One of "edger" (requires R, rpy2 and edgeR to be installed) or "batchglm"
269
281
 
270
282
  Returns:
@@ -288,8 +300,8 @@ class Milo:
288
300
  try:
289
301
  sample_adata = mdata["milo"]
290
302
  except KeyError:
291
- print(
292
- "[bold red]milo_mdata should be a MuData object with two slots:"
303
+ logger.error(
304
+ "milo_mdata should be a MuData object with two slots:"
293
305
  " feature_key and 'milo' - please run milopy.count_nhoods() first"
294
306
  )
295
307
  raise
@@ -303,7 +315,7 @@ class Milo:
303
315
  sample_obs = adata.obs[covariates + [sample_col]].drop_duplicates()
304
316
  except KeyError:
305
317
  missing_cov = [x for x in covariates if x not in sample_adata.obs.columns]
306
- print("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
318
+ logger.warning("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
307
319
  raise
308
320
  sample_obs = sample_obs[covariates + [sample_col]]
309
321
  sample_obs.index = sample_obs[sample_col].astype("str")
@@ -311,7 +323,7 @@ class Milo:
311
323
  try:
312
324
  assert sample_obs.loc[sample_adata.obs_names].shape[0] == len(sample_adata.obs_names)
313
325
  except AssertionError:
314
- print(
326
+ logger.warning(
315
327
  f"Values in mdata[{feature_key}].obs[{covariates}] cannot be unambiguously assigned to each sample"
316
328
  f" -- each sample value should match a single covariate value"
317
329
  )
@@ -323,7 +335,9 @@ class Milo:
323
335
  design_df = sample_adata.obs[covariates]
324
336
  except KeyError:
325
337
  missing_cov = [x for x in covariates if x not in sample_adata.obs.columns]
326
- print('Covariates {c} are not columns in adata.uns["sample_adata"].obs'.format(c=" ".join(missing_cov)))
338
+ logger.error(
339
+ 'Covariates {c} are not columns in adata.uns["sample_adata"].obs'.format(c=" ".join(missing_cov))
340
+ )
327
341
  raise
328
342
  # Get count matrix
329
343
  count_mat = sample_adata.X.T.toarray()
@@ -367,6 +381,8 @@ class Milo:
367
381
  return(colnames(m))
368
382
  }
369
383
  """
384
+ from rpy2.robjects.packages import STAP
385
+
370
386
  get_model_cols = STAP(r_str, "get_model_cols")
371
387
  model_mat_cols = get_model_cols.get_model_cols(design_df, design)
372
388
  model_df = pd.DataFrame(model)
@@ -374,13 +390,16 @@ class Milo:
374
390
  try:
375
391
  mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
376
392
  except ValueError:
377
- print("Model contrasts must be in the form 'A-B' or 'A+B'")
393
+ logger.error("Model contrasts must be in the form 'A-B' or 'A+B'")
378
394
  raise
379
395
  res = base.as_data_frame(
380
396
  edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
381
397
  )
382
398
  else:
383
399
  res = base.as_data_frame(edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf))
400
+
401
+ from rpy2.robjects import conversion
402
+
384
403
  res = conversion.rpy2py(res)
385
404
  if not isinstance(res, pd.DataFrame):
386
405
  res = pd.DataFrame(res)
@@ -405,7 +424,7 @@ class Milo:
405
424
  Args:
406
425
  mdata: MuData object
407
426
  anno_col: Column in adata.obs containing the cell annotations to use for nhood labelling
408
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
427
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
409
428
 
410
429
  Returns:
411
430
  None. Adds in place:
@@ -423,12 +442,12 @@ class Milo:
423
442
  >>> sc.pp.neighbors(mdata["rna"])
424
443
  >>> milo.make_nhoods(mdata["rna"])
425
444
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
426
- >>> milo.annotate_nhoods(mdata, anno_col='cell_type')
445
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
427
446
  """
428
447
  try:
429
448
  sample_adata = mdata["milo"]
430
449
  except KeyError:
431
- print(
450
+ logger.error(
432
451
  "milo_mdata should be a MuData object with two slots: feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
433
452
  )
434
453
  raise
@@ -459,7 +478,7 @@ class Milo:
459
478
  Args:
460
479
  mdata: MuData object
461
480
  anno_col: Column in adata.obs containing the cell annotations to use for nhood labelling
462
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
481
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
463
482
 
464
483
  Returns:
465
484
  None. Adds in place:
@@ -474,7 +493,7 @@ class Milo:
474
493
  >>> sc.pp.neighbors(mdata["rna"])
475
494
  >>> milo.make_nhoods(mdata["rna"])
476
495
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
477
- >>> milo.annotate_nhoods_continuous(mdata, anno_col='nUMI')
496
+ >>> milo.annotate_nhoods_continuous(mdata, anno_col="nUMI")
478
497
  """
479
498
  if "milo" not in mdata.mod:
480
499
  raise ValueError(
@@ -500,7 +519,7 @@ class Milo:
500
519
  Args:
501
520
  mdata: MuData object
502
521
  new_covariates: columns in `milo_mdata[feature_key].obs` to add to `milo_mdata['milo'].obs`.
503
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
522
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
504
523
 
505
524
  Returns:
506
525
  None, adds columns to `milo_mdata['milo']` in place
@@ -519,7 +538,7 @@ class Milo:
519
538
  try:
520
539
  sample_adata = mdata["milo"]
521
540
  except KeyError:
522
- print(
541
+ logger.error(
523
542
  "milo_mdata should be a MuData object with two slots: feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
524
543
  )
525
544
  raise
@@ -533,14 +552,14 @@ class Milo:
533
552
  sample_obs = adata.obs[covariates + [sample_col]].drop_duplicates()
534
553
  except KeyError:
535
554
  missing_cov = [covar for covar in covariates if covar not in sample_adata.obs.columns]
536
- print("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
555
+ logger.error("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
537
556
  raise
538
557
  sample_obs = sample_obs[covariates + [sample_col]].astype("str")
539
558
  sample_obs.index = sample_obs[sample_col]
540
559
  try:
541
560
  assert sample_obs.loc[sample_adata.obs_names].shape[0] == len(sample_adata.obs_names)
542
561
  except ValueError:
543
- print(
562
+ logger.error(
544
563
  "Covariates cannot be unambiguously assigned to each sample -- each sample value should match a single covariate value"
545
564
  )
546
565
  raise
@@ -551,8 +570,8 @@ class Milo:
551
570
 
552
571
  Args:
553
572
  mdata: MuData object
554
- basis: Name of the obsm basis to use for layout of neighbourhoods (key in `adata.obsm`). Defaults to "X_umap".
555
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
573
+ basis: Name of the obsm basis to use for layout of neighbourhoods (key in `adata.obsm`).
574
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
556
575
 
557
576
  Returns:
558
577
  - `milo_mdata['milo'].varp['nhood_connectivities']`: graph of overlap between neighbourhoods (i.e. no of shared cells)
@@ -584,13 +603,13 @@ class Milo:
584
603
  "distances_key": "",
585
604
  }
586
605
 
587
- def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_key: str | None = "rna"):
606
+ def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_key: str | None = "rna") -> None:
588
607
  """Calculates the mean expression in neighbourhoods of each feature.
589
608
 
590
609
  Args:
591
610
  mdata: MuData object
592
- layer: If provided, use `milo_mdata[feature_key][layer]` as expression matrix instead of `milo_mdata[feature_key].X`. Defaults to None.
593
- feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
611
+ layer: If provided, use `milo_mdata[feature_key][layer]` as expression matrix instead of `milo_mdata[feature_key].X`.
612
+ feature_key: If input data is MuData, specify key to cell-level AnnData object.
594
613
 
595
614
  Returns:
596
615
  Updates adata in place to store the matrix of average expression in each neighbourhood in `milo_mdata['milo'].varm['expr']`
@@ -609,7 +628,7 @@ class Milo:
609
628
  try:
610
629
  sample_adata = mdata["milo"]
611
630
  except KeyError:
612
- print(
631
+ logger.error(
613
632
  "milo_mdata should be a MuData object with two slots:"
614
633
  " feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
615
634
  )
@@ -633,6 +652,9 @@ class Milo:
633
652
  self,
634
653
  ):
635
654
  """Set up rpy2 to run edgeR"""
655
+ from rpy2.robjects import numpy2ri, pandas2ri
656
+ from rpy2.robjects.packages import importr
657
+
636
658
  numpy2ri.activate()
637
659
  pandas2ri.activate()
638
660
  edgeR = self._try_import_bioc_library("edgeR")
@@ -651,11 +673,13 @@ class Milo:
651
673
  Args:
652
674
  name (str): R packages name
653
675
  """
676
+ from rpy2.robjects.packages import PackageNotInstalledError, importr
677
+
654
678
  try:
655
679
  _r_lib = importr(name)
656
680
  return _r_lib
657
681
  except PackageNotInstalledError:
658
- print(f"Install Bioconductor library `{name!r}` first as `BiocManager::install({name!r}).`")
682
+ logger.error(f"Install Bioconductor library `{name!r}` first as `BiocManager::install({name!r}).`")
659
683
  raise
660
684
 
661
685
  def _graph_spatial_fdr(
@@ -663,11 +687,13 @@ class Milo:
663
687
  sample_adata: AnnData,
664
688
  neighbors_key: str | None = None,
665
689
  ):
666
- """FDR correction weighted on inverse of connectivity of neighbourhoods. The distance to the k-th nearest neighbor is used as a measure of connectivity.
690
+ """FDR correction weighted on inverse of connectivity of neighbourhoods.
691
+
692
+ The distance to the k-th nearest neighbor is used as a measure of connectivity.
667
693
 
668
694
  Args:
669
695
  sample_adata: Sample-level AnnData.
670
- neighbors_key: The key in `adata.obsp` to use as KNN graph. Defaults to None.
696
+ neighbors_key: The key in `adata.obsp` to use as KNN graph.
671
697
  """
672
698
  # use 1/connectivity as the weighting for the weighted BH adjustment from Cydar
673
699
  w = 1 / sample_adata.var["kth_distance"]
@@ -686,3 +712,334 @@ class Milo:
686
712
 
687
713
  sample_adata.var["SpatialFDR"] = np.nan
688
714
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
715
+
716
def plot_nhood_graph(
    self,
    mdata: MuData,
    alpha: float = 0.1,
    min_logFC: float = 0,
    min_size: int = 10,
    plot_edges: bool = False,
    title: str = "DA log-Fold Change",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Figure | Axes | list[Axes] | None:
    """Visualize DA results on the abstracted neighbourhood graph (wrapper around `sc.pl.embedding`).

    Neighbourhoods with `SpatialFDR > alpha` or `|logFC| < min_logFC` are drawn without colour;
    the remaining ones are coloured by their logFC on a colour scale symmetric around 0.

    Args:
        mdata: MuData object whose 'milo' modality holds DA results (`logFC`, `SpatialFDR`) and the
            neighbourhood graph (`Nhood_size`, `X_milo_graph`) in its `.var` / `.varm`.
        alpha: Significance threshold on `SpatialFDR`.
        min_logFC: Minimum absolute log-Fold Change to colour a neighbourhood. If 0, colour all significant neighbourhoods.
        min_size: Scaling factor for node sizes; each node is drawn with size `Nhood_size * min_size`.
        plot_edges: Whether edges for neighbourhood overlaps should be plotted.
        title: Plot title.
        color_map: Colormap for the logFC values. If None, "RdBu_r" is used.
        palette: Palette passed on to `scanpy.pl.embedding`.
        ax: Matplotlib axes to plot into.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwargs: Additional arguments to `scanpy.pl.embedding`.

    Returns:
        Whatever `scanpy.pl.embedding` returns (e.g. the axis when `show=False`, otherwise None).

    Raises:
        KeyError: If `build_nhood_graph` or `da_nhoods` has not been run on `mdata`.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> sc.tl.umap(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata,
        ...     design='~label',
        ...     model_contrasts='labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine')
        >>> milo.build_nhood_graph(mdata)
        >>> milo.plot_nhood_graph(mdata)

    Preview:
        .. image:: /_static/docstring_previews/milo_nhood_graph.png
    """
    # Neighbourhoods become observations so scanpy can plot them as points.
    nhood_adata = mdata["milo"].T.copy()

    if "Nhood_size" not in nhood_adata.obs.columns:
        raise KeyError(
            'Cannot find "Nhood_size" column in mdata["milo"].var -- please run milo.build_nhood_graph(mdata) first'
        )
    missing_da = [col for col in ("logFC", "SpatialFDR") if col not in nhood_adata.obs.columns]
    if missing_da:
        raise KeyError(f"Cannot find {missing_da} in mdata['milo'].var -- please run milo.da_nhoods(mdata) first")

    obs = nhood_adata.obs
    # Only neighbourhoods passing both the FDR and the effect-size filter keep a colour.
    obs["graph_color"] = obs["logFC"]
    obs.loc[obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
    obs["abs_logFC"] = obs["logFC"].abs()
    obs.loc[obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan

    # Plotting order: uncoloured neighbourhoods first, then increasing |logFC| so extreme ones end up on top.
    obs.loc[obs["graph_color"].isna(), "abs_logFC"] = np.nan
    ordered = obs.sort_values("abs_logFC", na_position="first").index
    nhood_adata = nhood_adata[ordered]

    # Symmetric colour range around 0 keeps the diverging colormap centred.
    vmax = nhood_adata.obs["graph_color"].abs().max()
    if np.isnan(vmax):
        # Nothing passed the filters: let scanpy autoscale instead of passing NaN limits.
        vmax = None
    vmin = None if vmax is None else -vmax

    # scanpy rejects `cmap` and `color_map` together, so the user's choice is folded into `cmap`.
    return sc.pl.embedding(
        nhood_adata,
        "X_milo_graph",
        color="graph_color",
        cmap="RdBu_r" if color_map is None else color_map,
        size=nhood_adata.obs["Nhood_size"] * min_size,
        edges=plot_edges,
        neighbors_key="nhood",
        sort_order=False,
        frameon=False,
        vmax=vmax,
        vmin=vmin,
        title=title,
        palette=palette,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
805
+
806
def plot_nhood(
    self,
    mdata: MuData,
    ix: int,
    feature_key: str | None = "rna",
    basis: str = "X_umap",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Figure | Axes | list[Axes] | None:
    """Visualize cells in a neighbourhood.

    Side effect: the membership of every cell in neighbourhood `ix` is written to
    `mdata[feature_key].obs["Nhood"]` (this column is used for colouring and left in place).

    Args:
        mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
        ix: Index of the neighbourhood (column of `obsm['nhoods']`) to visualize.
        feature_key: Key of the cell-level AnnData object in `mdata`.
        basis: Embedding to use for visualization.
        color_map: Colormap passed on to `scanpy.pl.embedding`.
        palette: Palette passed on to `scanpy.pl.embedding`.
        return_fig: Passed on to `scanpy.pl.embedding`; if True, its Figure is returned.
        ax: Matplotlib axes to plot into.
        show: Show the plot, do not return axis.
        save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
        **kwargs: Additional arguments to `scanpy.pl.embedding`.

    Returns:
        Whatever `scanpy.pl.embedding` returns (Figure, axis or None depending on `return_fig` / `show`).

    Raises:
        KeyError: If `make_nhoods` has not been run on `mdata[feature_key]`.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> sc.tl.umap(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> milo.plot_nhood(mdata, ix=0)

    Preview:
        .. image:: /_static/docstring_previews/milo_nhood.png
    """
    cell_adata = mdata[feature_key]
    if "nhoods" not in cell_adata.obsm:
        raise KeyError('Cannot find "nhoods" slot in adata.obsm -- please run milo.make_nhoods(adata) first')

    # Column `ix` of the (sparse) cell x neighbourhood membership matrix, flattened to one value per cell.
    cell_adata.obs["Nhood"] = cell_adata.obsm["nhoods"][:, ix].toarray().ravel()

    # Returning scanpy's result is what makes `return_fig` / `show=False` meaningful.
    return sc.pl.embedding(
        cell_adata,
        basis,
        color="Nhood",
        size=30,
        title=f"Nhood{ix}",
        color_map=color_map,
        palette=palette,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
859
+
860
def plot_da_beeswarm(
    self,
    mdata: MuData,
    feature_key: str | None = "rna",
    anno_col: str = "nhood_annotation",
    alpha: float = 0.1,
    subset_nhoods: list[str] | None = None,
    palette: str | Sequence[str] | dict[str, str] | None = None,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    show: bool | None = None,
) -> Figure | Axes | None:
    """Plot beeswarm plot of logFC against nhood labels.

    One violin per annotation label (ordered by median logFC) shows the logFC distribution; the points
    are individual neighbourhoods, coloured by whether their SpatialFDR is below `alpha`.

    Args:
        mdata: MuData object with a 'milo' modality holding DA results.
        feature_key: Key of the cell-level AnnData in `mdata`; used to look up the default category colours.
        anno_col: Column in `mdata['milo'].var` to use as annotation.
        alpha: Significance threshold on SpatialFDR.
        subset_nhoods: Annotation labels (values of `anno_col`) to restrict the plot to. If None, plot all nhoods.
        palette: Seaborn palette for the violins. If None, the colours stored in
            `mdata[feature_key].uns[f"{obs_col}_colors"]` for the annotated obs column are used when available,
            otherwise seaborn's defaults.
        return_fig: If True (and neither `save` nor `show` is set), return the current Figure.
        save: Passed to `plt.savefig` as the target; when set, the figure is saved and None is returned.
            NOTE(review): `save=True` is forwarded unchanged to `plt.savefig`, which expects a path — confirm intended handling.
        show: If True, show the plot and return None.

    Returns:
        None if `save` or `show` is set, the current Figure if `return_fig` is True, otherwise the current Axes.

    Raises:
        RuntimeError: If the 'milo' modality, `anno_col`, or the DA results (`logFC`, `SpatialFDR`) are missing.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata, design="~label")
        >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
        >>> milo.plot_da_beeswarm(mdata)

    Preview:
        .. image:: /_static/docstring_previews/milo_da_beeswarm.png
    """
    try:
        nhood_adata = mdata["milo"].T.copy()
    except KeyError:
        raise RuntimeError(
            "mdata should be a MuData object with two slots: feature_key and 'milo'. Run 'milo.count_nhoods(mdata)' first."
        ) from None

    if anno_col not in nhood_adata.obs.columns:
        raise RuntimeError(
            f"Unable to find {anno_col} in mdata['milo'].var. Run 'milo.annotate_nhoods(mdata, anno_col)' first"
        )
    if "logFC" not in nhood_adata.obs.columns or "SpatialFDR" not in nhood_adata.obs.columns:
        raise RuntimeError("Unable to find 'logFC'/'SpatialFDR' in mdata['milo'].var. Run 'milo.da_nhoods(mdata)' first.")

    if subset_nhoods is not None:
        nhood_adata = nhood_adata[nhood_adata.obs[anno_col].isin(subset_nhoods)]

    anno_df = nhood_adata.obs[[anno_col, "logFC", "SpatialFDR"]].copy()
    anno_df["is_signif"] = anno_df["SpatialFDR"] < alpha
    anno_df = anno_df[anno_df[anno_col] != "nan"]

    # Order labels by median logFC, using only labels that are actually plotted
    # (observed=True drops unused categories, which would otherwise show up as empty rows).
    sorted_annos = anno_df.groupby(anno_col, observed=True)["logFC"].median().sort_values(ascending=True).index

    if palette is None:
        # Default to the category colours scanpy stored for the obs column used by annotate_nhoods.
        try:
            obs_col = nhood_adata.uns["annotation_obs"]
            palette = dict(
                zip(
                    mdata[feature_key].obs[obs_col].cat.categories,
                    mdata[feature_key].uns[f"{obs_col}_colors"],
                    strict=False,
                )
            )
        except (KeyError, AttributeError):
            # Missing annotation info / colours, or a non-categorical column: use seaborn's defaults.
            palette = None

    violin_kwargs = {
        "data": anno_df,
        "y": anno_col,
        "x": "logFC",
        "order": sorted_annos,
        "inner": None,
        "orient": "h",
        "linewidth": 0,
        "scale": "width",
    }
    if palette is None:
        sns.violinplot(**violin_kwargs)
    else:
        try:
            sns.violinplot(palette=palette, **violin_kwargs)
        except (ValueError, TypeError):
            # e.g. a palette dict that lacks some annotation labels: fall back to seaborn's defaults.
            sns.violinplot(**violin_kwargs)

    sns.stripplot(
        data=anno_df,
        y=anno_col,
        x="logFC",
        order=sorted_annos,
        size=2,
        hue="is_signif",
        palette=["grey", "black"],
        orient="h",
        alpha=0.5,
    )
    # `:g` avoids float truncation artefacts such as int(0.29 * 100) == 28.
    plt.legend(loc="upper left", title=f"< {alpha * 100:g}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
    plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")

    if save:
        plt.savefig(save, bbox_inches="tight")
        return None
    if show:
        plt.show()
        return None
    if return_fig:
        return plt.gcf()
    return plt.gca()
988
+
989
def plot_nhood_counts_by_cond(
    self,
    mdata: MuData,
    test_var: str,
    subset_nhoods: list[str] | None = None,
    log_counts: bool = False,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    show: bool | None = None,
) -> Figure | Axes | None:
    """Plot boxplot of cell numbers vs condition of interest.

    Each point is the number of cells one sample contributes to one neighbourhood.

    Args:
        mdata: MuData object storing cell level and nhood level information
        test_var: Sample-level column (of `mdata['milo'].obs`, e.g. a covariate added by `da_nhoods`)
            storing the condition of interest; used as the x-axis of the boxplot.
        subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
        log_counts: Whether to plot log1p of cell counts.
        return_fig: If True (and neither `save` nor `show` is set), return the current Figure.
        save: Passed to `plt.savefig` as the target; when set, the figure is saved and None is returned.
            NOTE(review): `save=True` is forwarded unchanged to `plt.savefig`, which expects a path — confirm intended handling.
        show: If True, show the plot and return None.

    Returns:
        None if `save` or `show` is set, the current Figure if `return_fig` is True, otherwise the current Axes.

    Raises:
        RuntimeError: If `mdata` has no 'milo' modality.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata, design="~label")
        >>> milo.plot_nhood_counts_by_cond(mdata, test_var="label")
    """
    try:
        nhood_adata = mdata["milo"].T.copy()
    except KeyError:
        raise RuntimeError(
            "mdata should be a MuData object with two slots: feature_key and 'milo'. Run milo.count_nhoods(mdata) first"
        ) from None

    if subset_nhoods is None:
        subset_nhoods = nhood_adata.obs_names

    counts = nhood_adata[subset_nhoods].X
    # The `.A` shortcut was removed from SciPy sparse matrices; densify via `toarray()` and accept dense X too.
    counts = counts.toarray() if hasattr(counts, "toarray") else np.asarray(counts)

    # Long format: one row per (neighbourhood, sample) with the sample's metadata merged in.
    pl_df = pd.DataFrame(counts, columns=nhood_adata.var_names).melt(
        var_name=nhood_adata.uns["sample_col"], value_name="n_cells"
    )
    pl_df = pd.merge(pl_df, nhood_adata.var)
    pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])

    y_col, y_label = ("log_n_cells", "log(# cells + 1)") if log_counts else ("n_cells", "# cells")
    sns.boxplot(data=pl_df, x=test_var, y=y_col, color="lightblue")
    sns.stripplot(data=pl_df, x=test_var, y=y_col, color="black", s=3)
    plt.ylabel(y_label)
    plt.xticks(rotation=90)
    plt.xlabel(test_var)

    if save:
        plt.savefig(save, bbox_inches="tight")
        return None
    if show:
        plt.show()
        return None
    if return_fig:
        return plt.gcf()
    return plt.gca()