pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_milo.py CHANGED
@@ -3,14 +3,23 @@ from __future__ import annotations
3
3
  import logging
4
4
  import random
5
5
  import re
6
- from typing import Literal
6
+ from typing import TYPE_CHECKING, Literal
7
7
 
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
11
+ import scanpy as sc
12
+ import seaborn as sns
10
13
  from anndata import AnnData
11
14
  from mudata import MuData
12
15
  from rich import print
13
16
 
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Sequence
19
+
20
+ from matplotlib.axes import Axes
21
+ from matplotlib.colors import Colormap
22
+
14
23
  try:
15
24
  from rpy2.robjects import conversion, numpy2ri, pandas2ri
16
25
  from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
@@ -39,7 +48,7 @@ class Milo:
39
48
  input: AnnData
40
49
  feature_key: Key to store the cell-level AnnData object in the MuData object
41
50
  Returns:
42
- MuData: MuData object with original AnnData (default is `mudata[feature_key]`).
51
+ MuData: MuData object with original AnnData. Defaults to `mudata[feature_key]`.
43
52
 
44
53
  Examples:
45
54
  >>> import pertpy as pt
@@ -71,11 +80,11 @@ class Milo:
71
80
  neighbors_key: The key in `adata.obsp` or `mdata[feature_key].obsp` to use as KNN graph.
72
81
  If not specified, `make_nhoods` looks at .obsp['connectivities'] for connectivities (default storage place for `scanpy.pp.neighbors`).
73
82
  If specified, it looks at .obsp[.uns[neighbors_key]['connectivities_key']] for connectivities.
74
- (default: None)
75
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
76
- prop: Fraction of cells to sample for neighbourhood index search. (default: 0.1)
77
- seed: Random seed for cell sampling. (default: 0)
78
- copy: Determines whether a copy of the `adata` is returned. (default: False)
83
+ Defaults to None.
84
+ feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
85
+ prop: Fraction of cells to sample for neighbourhood index search. Defaults to 0.1.
86
+ seed: Random seed for cell sampling. Defaults to 0.
87
+ copy: Determines whether a copy of the `adata` is returned. Defaults to False.
79
88
 
80
89
  Returns:
81
90
  If `copy=True`, returns the copy of `adata` with the result in `.obs`, `.obsm`, and `.uns`.
@@ -190,7 +199,7 @@ class Milo:
190
199
  Args:
191
200
  data: AnnData object with neighbourhoods defined in `obsm['nhoods']` or MuData object with a modality with neighbourhoods defined in `obsm['nhoods']`
192
201
  sample_col: Column in adata.obs that contains sample information
193
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
202
+ feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
194
203
 
195
204
  Returns:
196
205
  MuData object storing the original (i.e. rna) AnnData in `mudata[feature_key]`
@@ -423,7 +432,7 @@ class Milo:
423
432
  >>> sc.pp.neighbors(mdata["rna"])
424
433
  >>> milo.make_nhoods(mdata["rna"])
425
434
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
426
- >>> milo.annotate_nhoods(mdata, anno_col='cell_type')
435
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
427
436
  """
428
437
  try:
429
438
  sample_adata = mdata["milo"]
@@ -474,7 +483,7 @@ class Milo:
474
483
  >>> sc.pp.neighbors(mdata["rna"])
475
484
  >>> milo.make_nhoods(mdata["rna"])
476
485
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
477
- >>> milo.annotate_nhoods_continuous(mdata, anno_col='nUMI')
486
+ >>> milo.annotate_nhoods_continuous(mdata, anno_col="nUMI")
478
487
  """
479
488
  if "milo" not in mdata.mod:
480
489
  raise ValueError(
@@ -663,7 +672,9 @@ class Milo:
663
672
  sample_adata: AnnData,
664
673
  neighbors_key: str | None = None,
665
674
  ):
666
- """FDR correction weighted on inverse of connectivity of neighbourhoods. The distance to the k-th nearest neighbor is used as a measure of connectivity.
675
+ """FDR correction weighted on inverse of connectivity of neighbourhoods.
676
+
677
+ The distance to the k-th nearest neighbor is used as a measure of connectivity.
667
678
 
668
679
  Args:
669
680
  sample_adata: Sample-level AnnData.
@@ -686,3 +697,326 @@ class Milo:
686
697
 
687
698
  sample_adata.var["SpatialFDR"] = np.nan
688
699
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
700
+
701
def plot_nhood_graph(
    self,
    mdata: MuData,
    alpha: float = 0.1,
    min_logFC: float = 0,
    min_size: int = 10,
    plot_edges: bool = False,
    title: str = "DA log-Fold Change",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Axes | list[Axes] | None:
    """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding).

    Neighbourhoods are coloured by their logFC. Neighbourhoods whose SpatialFDR is above
    `alpha`, or whose absolute logFC is below `min_logFC`, are left uncoloured (NaN).

    Args:
        mdata: MuData object with a 'milo' modality holding the DA results.
        alpha: Significance threshold on SpatialFDR. Defaults to 0.1.
        min_logFC: Minimum absolute log-Fold Change to show results. If 0, show all significant neighbourhoods.
            Defaults to 0.
        min_size: Scaling factor for node sizes; each node is drawn with size `Nhood_size * min_size`.
            Defaults to 10.
        plot_edges: Whether edges for neighbourhood overlaps should be plotted. Defaults to False.
        title: Plot title. Defaults to "DA log-Fold Change".
        color_map: Colormap used for the logFC values. If None, "RdBu_r" is used. Defaults to None.
        palette: Forwarded to `scanpy.pl.embedding`. Defaults to None.
        ax: Matplotlib axes to draw on, forwarded to `scanpy.pl.embedding`. Defaults to None.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwargs: Additional arguments to `scanpy.pl.embedding`.

    Returns:
        Whatever `scanpy.pl.embedding` returns for the given `show` setting.

    Raises:
        KeyError: If no "Nhood_size" column is present in `mdata["milo"].var`.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> sc.tl.umap(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata,
        >>>            design='~label',
        >>>            model_contrasts='labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine')
        >>> milo.build_nhood_graph(mdata)
        >>> milo.plot_nhood_graph(mdata)

    Preview:
        .. image:: /_static/docstring_previews/milo_nhood_graph.png
    """
    # Transposed copy: neighbourhoods become observations, so per-nhood stats live in `.obs`.
    nhood_adata = mdata["milo"].T.copy()

    if "Nhood_size" not in nhood_adata.obs.columns:
        # Adjacent literals instead of a backslash continuation, which would embed the
        # source indentation into the message.
        raise KeyError(
            'Cannot find "Nhood_size" column in mdata["milo"].var -- '
            "please run Milo.build_nhood_graph(mdata) first"
        )

    # Colour by logFC, blanking out non-significant or too-small effects.
    nhood_adata.obs["graph_color"] = nhood_adata.obs["logFC"]
    nhood_adata.obs.loc[nhood_adata.obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
    nhood_adata.obs["abs_logFC"] = nhood_adata.obs["logFC"].abs()
    nhood_adata.obs.loc[nhood_adata.obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan

    # Plotting order - extreme logFC on top, uncoloured neighbourhoods at the bottom.
    nhood_adata.obs.loc[nhood_adata.obs["graph_color"].isna(), "abs_logFC"] = np.nan
    ordered = nhood_adata.obs.sort_values("abs_logFC", na_position="first").index
    nhood_adata = nhood_adata[ordered]

    # Symmetric colour limits so that zero logFC sits at the centre of the colormap.
    vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
    vmin = -vmax

    # Resolve the colormap once and pass it through a single keyword: forwarding both
    # `cmap` and `color_map` makes a user-supplied `color_map` fail inside scanpy.
    cmap = "RdBu_r" if color_map is None else color_map

    return sc.pl.embedding(
        nhood_adata,
        "X_milo_graph",
        color="graph_color",
        cmap=cmap,
        size=nhood_adata.obs["Nhood_size"] * min_size,
        edges=plot_edges,
        neighbors_key="nhood",
        sort_order=False,  # keep the explicit ordering computed above
        frameon=False,
        vmax=vmax,
        vmin=vmin,
        title=title,
        palette=palette,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
790
+
791
def plot_nhood(
    self,
    mdata: MuData,
    ix: int,
    feature_key: str | None = "rna",
    basis: str = "X_umap",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> plt.Figure | Axes | list[Axes] | None:
    """Visualize cells in a neighbourhood.

    Writes the membership indicator of neighbourhood `ix` to `mdata[feature_key].obs["Nhood"]`
    (overwriting any existing column of that name) and colours the embedding by it.

    Args:
        mdata: MuData object with feature_key slot, storing neighbourhood assignments in
            `mdata[feature_key].obsm['nhoods']`.
        ix: Index of neighbourhood to visualize.
        feature_key: Key of the cell-level AnnData object in `mdata`. Defaults to 'rna'.
        basis: Embedding to use for visualization. Defaults to "X_umap".
        color_map: Forwarded to `scanpy.pl.embedding`. Defaults to None.
        palette: Forwarded to `scanpy.pl.embedding`. Defaults to None.
        return_fig: Forwarded to `scanpy.pl.embedding`. Defaults to None.
        ax: Matplotlib axes to draw on, forwarded to `scanpy.pl.embedding`. Defaults to None.
        show: Show the plot, do not return axis.
        save: If True or a str, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
        **kwargs: Additional arguments to `scanpy.pl.embedding`.

    Returns:
        Whatever `scanpy.pl.embedding` returns for the given `show` / `return_fig` settings.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> sc.tl.umap(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> milo.plot_nhood(mdata, ix=0)

    Preview:
        .. image:: /_static/docstring_previews/milo_nhood.png
    """
    cell_adata = mdata[feature_key]
    # Column `ix` of the nhoods matrix flags the cells belonging to that neighbourhood.
    cell_adata.obs["Nhood"] = cell_adata.obsm["nhoods"][:, ix].toarray().ravel()

    # Hand the result back to the caller; otherwise `return_fig` / `show=False` have no effect.
    return sc.pl.embedding(
        cell_adata,
        basis,
        color="Nhood",
        size=30,
        title="Nhood" + str(ix),
        color_map=color_map,
        palette=palette,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
844
+
845
def plot_da_beeswarm(
    self,
    mdata: MuData,
    feature_key: str | None = "rna",
    anno_col: str = "nhood_annotation",
    alpha: float = 0.1,
    subset_nhoods: list[str] | None = None,
    palette: str | Sequence[str] | dict[str, str] | None = None,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    show: bool | None = None,
) -> plt.Figure | Axes | None:
    """Plot beeswarm plot of logFC against nhood labels.

    Args:
        mdata: MuData object with a 'milo' modality holding DA results and neighbourhood annotations.
        feature_key: Key of the cell-level AnnData object in `mdata`. Only used to look up
            default category colors when `palette` is None. Defaults to 'rna'.
        anno_col: Column in `mdata['milo'].var` to use as annotation. Defaults to 'nhood_annotation'.
        alpha: Significance threshold on SpatialFDR. Defaults to 0.1.
        subset_nhoods: Annotation labels (values of `anno_col`) to plot. If None, plot all nhoods.
            Defaults to None.
        palette: Name of Seaborn color palette for violinplots.
            Defaults to pre-defined category colors for violinplots.
        return_fig: If True, return the current figure.
        save: Path passed to `matplotlib.pyplot.savefig`.
        show: If True, call `matplotlib.pyplot.show`.

    Returns:
        The current figure if `return_fig` is set; otherwise the current axes when the plot
        is neither shown nor saved; otherwise None.

    Raises:
        RuntimeError: If the 'milo' modality, `anno_col` or 'logFC' is missing.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata, design="~label")
        >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
        >>> milo.plot_da_beeswarm(mdata)

    Preview:
        .. image:: /_static/docstring_previews/milo_da_beeswarm.png
    """
    try:
        # Transposed copy: neighbourhoods become observations.
        nhood_adata = mdata["milo"].T.copy()
    except KeyError:
        raise RuntimeError(
            "mdata should be a MuData object with two slots: feature_key and 'milo'. Run 'milopy.count_nhoods(adata)' first."
        ) from None

    if anno_col not in nhood_adata.obs:
        raise RuntimeError(
            f"Unable to find {anno_col} in mdata['milo'].var. Run 'milopy.utils.annotate_nhoods(adata, anno_col)' first"
        )

    if subset_nhoods is not None:
        nhood_adata = nhood_adata[nhood_adata.obs[anno_col].isin(subset_nhoods)]

    if "logFC" not in nhood_adata.obs:
        raise RuntimeError(
            "Unable to find 'logFC' in mdata.uns['nhood_adata'].obs. Run 'core.da_nhoods(adata)' first."
        )

    # Order annotation groups by their median logFC.
    sorted_annos = (
        nhood_adata.obs[[anno_col, "logFC"]].groupby(anno_col).median().sort_values("logFC", ascending=True).index
    )

    anno_df = nhood_adata.obs[[anno_col, "logFC", "SpatialFDR"]].copy()
    anno_df["is_signif"] = anno_df["SpatialFDR"] < alpha
    anno_df = anno_df[anno_df[anno_col] != "nan"]

    # Only derive a palette from the stored category colors when the caller gave none,
    # so a user-supplied palette is never discarded because of a failed lookup.
    if palette is None:
        try:
            obs_col = nhood_adata.uns["annotation_obs"]
            palette = dict(
                zip(
                    mdata[feature_key].obs[obs_col].cat.categories,
                    mdata[feature_key].uns[f"{obs_col}_colors"],
                    strict=False,
                )
            )
        except (KeyError, AttributeError):
            # No stored annotation / colors, or a non-categorical column: use seaborn defaults.
            palette = None

    # NOTE(review): newer seaborn releases rename `scale` to `density_norm` -- confirm against the pinned version.
    violin_kwargs = {
        "data": anno_df,
        "y": anno_col,
        "x": "logFC",
        "order": sorted_annos,
        "inner": None,
        "orient": "h",
        "linewidth": 0,
        "scale": "width",
    }
    try:
        sns.violinplot(palette=palette, **violin_kwargs)
    except Exception:  # noqa: BLE001
        # Best-effort fallback, e.g. when the palette does not cover every annotation label.
        # `Exception` (not `BaseException`) so KeyboardInterrupt/SystemExit still propagate.
        sns.violinplot(**violin_kwargs)

    sns.stripplot(
        data=anno_df,
        y=anno_col,
        x="logFC",
        order=sorted_annos,
        size=2,
        hue="is_signif",
        palette=["grey", "black"],
        orient="h",
        alpha=0.5,
    )
    # `:g` formatting avoids the truncation of int(alpha * 100) (e.g. 0.29 -> 28, 0.005 -> 0).
    plt.legend(
        loc="upper left",
        title=f"< {alpha * 100:g}% SpatialFDR",
        bbox_to_anchor=(1, 1),
        frameon=False,
    )
    plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")

    # NOTE(review): `save=True` is passed straight to `plt.savefig`, which expects a path -- confirm intended use.
    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return plt.gca()
    return None
969
+
970
def plot_nhood_counts_by_cond(
    self,
    mdata: MuData,
    test_var: str,
    subset_nhoods: list[str] | None = None,
    log_counts: bool = False,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    show: bool | None = None,
) -> plt.Figure | Axes | None:
    """Plot boxplot of cell numbers vs condition of interest.

    Args:
        mdata: MuData object storing cell level and nhood level information.
        test_var: Name of the column storing the condition of interest (x-axis of the boxplot).
            It must be available after merging the counts with `mdata["milo"].T.var`.
        subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
            Defaults to None.
        log_counts: Whether to plot log1p of cell counts. Defaults to False.
        return_fig: If True, return the current figure.
        save: Path passed to `matplotlib.pyplot.savefig`.
        show: If True, call `matplotlib.pyplot.show`.

    Returns:
        The current figure if `return_fig` is set; otherwise the current axes when the plot
        is neither shown nor saved; otherwise None.

    Raises:
        RuntimeError: If `mdata` has no 'milo' modality.
    """
    try:
        # Transposed copy: neighbourhoods become observations, samples become variables.
        nhood_adata = mdata["milo"].T.copy()
    except KeyError:
        raise RuntimeError(
            "mdata should be a MuData object with two slots: feature_key and 'milo'. Run milopy.count_nhoods(mdata) first"
        ) from None

    if subset_nhoods is None:
        subset_nhoods = nhood_adata.obs_names

    # `.toarray()` rather than the deprecated sparse `.A` shortcut (same dense result).
    counts = nhood_adata[subset_nhoods].X.toarray()
    # Long format: one row per (neighbourhood, sample) pair with its cell count.
    pl_df = pd.DataFrame(counts, columns=nhood_adata.var_names).melt(
        var_name=nhood_adata.uns["sample_col"], value_name="n_cells"
    )
    # Attach the sample-level columns (joined on the columns both frames share).
    pl_df = pd.merge(pl_df, nhood_adata.var)

    if log_counts:
        pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
        y_col, y_label = "log_n_cells", "log(# cells + 1)"
    else:
        y_col, y_label = "n_cells", "# cells"

    sns.boxplot(data=pl_df, x=test_var, y=y_col, color="lightblue")
    sns.stripplot(data=pl_df, x=test_var, y=y_col, color="black", s=3)
    plt.ylabel(y_label)

    plt.xticks(rotation=90)
    plt.xlabel(test_var)

    # NOTE(review): `save=True` is passed straight to `plt.savefig`, which expects a path -- confirm intended use.
    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return plt.gca()
    return None