pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_milo.py CHANGED
@@ -3,14 +3,23 @@ from __future__ import annotations
3
3
  import logging
4
4
  import random
5
5
  import re
6
- from typing import Literal
6
+ from typing import TYPE_CHECKING, Literal
7
7
 
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
11
+ import scanpy as sc
12
+ import seaborn as sns
10
13
  from anndata import AnnData
11
14
  from mudata import MuData
12
15
  from rich import print
13
16
 
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Sequence
19
+
20
+ from matplotlib.axes import Axes
21
+ from matplotlib.colors import Colormap
22
+
14
23
  try:
15
24
  from rpy2.robjects import conversion, numpy2ri, pandas2ri
16
25
  from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
@@ -39,7 +48,7 @@ class Milo:
39
48
  input: AnnData
40
49
  feature_key: Key to store the cell-level AnnData object in the MuData object
41
50
  Returns:
42
- MuData: MuData object with original AnnData (default is `mudata[feature_key]`).
51
+ MuData: MuData object with original AnnData. Defaults to`mudata[feature_key]`.
43
52
 
44
53
  Examples:
45
54
  >>> import pertpy as pt
@@ -71,11 +80,11 @@ class Milo:
71
80
  neighbors_key: The key in `adata.obsp` or `mdata[feature_key].obsp` to use as KNN graph.
72
81
  If not specified, `make_nhoods` looks .obsp[‘connectivities’] for connectivities (default storage places for `scanpy.pp.neighbors`).
73
82
  If specified, it looks at .obsp[.uns[neighbors_key][‘connectivities_key’]] for connectivities.
74
- (default: None)
75
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
76
- prop: Fraction of cells to sample for neighbourhood index search. (default: 0.1)
77
- seed: Random seed for cell sampling. (default: 0)
78
- copy: Determines whether a copy of the `adata` is returned. (default: False)
83
+ Defaults to None.
84
+ feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
85
+ prop: Fraction of cells to sample for neighbourhood index search. Defaults to 0.1.
86
+ seed: Random seed for cell sampling. Defaults to 0.
87
+ copy: Determines whether a copy of the `adata` is returned. Defaults to False.
79
88
 
80
89
  Returns:
81
90
  If `copy=True`, returns the copy of `adata` with the result in `.obs`, `.obsm`, and `.uns`.
@@ -190,7 +199,7 @@ class Milo:
190
199
  Args:
191
200
  data: AnnData object with neighbourhoods defined in `obsm['nhoods']` or MuData object with a modality with neighbourhoods defined in `obsm['nhoods']`
192
201
  sample_col: Column in adata.obs that contains sample information
193
- feature_key: If input data is MuData, specify key to cell-level AnnData object. (default: 'rna')
202
+ feature_key: If input data is MuData, specify key to cell-level AnnData object. Defaults to 'rna'.
194
203
 
195
204
  Returns:
196
205
  MuData object storing the original (i.e. rna) AnnData in `mudata[feature_key]`
@@ -423,7 +432,7 @@ class Milo:
423
432
  >>> sc.pp.neighbors(mdata["rna"])
424
433
  >>> milo.make_nhoods(mdata["rna"])
425
434
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
426
- >>> milo.annotate_nhoods(mdata, anno_col='cell_type')
435
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
427
436
  """
428
437
  try:
429
438
  sample_adata = mdata["milo"]
@@ -474,7 +483,7 @@ class Milo:
474
483
  >>> sc.pp.neighbors(mdata["rna"])
475
484
  >>> milo.make_nhoods(mdata["rna"])
476
485
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
477
- >>> milo.annotate_nhoods_continuous(mdata, anno_col='nUMI')
486
+ >>> milo.annotate_nhoods_continuous(mdata, anno_col="nUMI")
478
487
  """
479
488
  if "milo" not in mdata.mod:
480
489
  raise ValueError(
@@ -663,7 +672,9 @@ class Milo:
663
672
  sample_adata: AnnData,
664
673
  neighbors_key: str | None = None,
665
674
  ):
666
- """FDR correction weighted on inverse of connectivity of neighbourhoods. The distance to the k-th nearest neighbor is used as a measure of connectivity.
675
+ """FDR correction weighted on inverse of connectivity of neighbourhoods.
676
+
677
+ The distance to the k-th nearest neighbor is used as a measure of connectivity.
667
678
 
668
679
  Args:
669
680
  sample_adata: Sample-level AnnData.
@@ -686,3 +697,326 @@ class Milo:
686
697
 
687
698
  sample_adata.var["SpatialFDR"] = np.nan
688
699
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
700
+
701
def plot_nhood_graph(
    self,
    mdata: MuData,
    alpha: float = 0.1,
    min_logFC: float = 0,
    min_size: int = 10,
    plot_edges: bool = False,
    title: str = "DA log-Fold Change",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Axes | list[Axes] | None:
    """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding).

    Each neighbourhood is drawn as a node colored by its logFC. Neighbourhoods with
    `SpatialFDR > alpha` or `|logFC| < min_logFC` are left uncolored (NaN). The color scale
    is symmetric around 0. Nodes are drawn in order of increasing absolute logFC, so the
    most extreme changes end up on top.

    Args:
        mdata: MuData object with a 'milo' modality. Its `.var` must hold the neighbourhood-level
            columns `logFC`, `SpatialFDR` and `Nhood_size`. It must also carry the `X_milo_graph`
            embedding used for plotting (presumably written by `build_nhood_graph`).
        alpha: Significance threshold on `SpatialFDR`. Defaults to 0.1.
        min_logFC: Minimum absolute log-Fold Change to show results. If is 0, show all significant
            neighbourhoods. Defaults to 0.
        min_size: Multiplier applied to `Nhood_size` to obtain the node sizes. Defaults to 10.
        plot_edges: If edges for neighbourhood overlaps should be plotted. Defaults to False.
        title: Plot title. Defaults to "DA log-Fold Change".
        color_map: Colormap for the logFC values. Defaults to "RdBu_r" when None.
        palette: Passed through to `scanpy.pl.embedding`.
        ax: Matplotlib axes to draw on.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwargs: Additional arguments to `scanpy.pl.embedding`.

    Returns:
        The return value of `scanpy.pl.embedding` (typically the axes when `show=False`).

    Raises:
        KeyError: If the `Nhood_size` column is missing from `mdata['milo'].var`.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> sc.tl.umap(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata,
        ...     design='~label',
        ...     model_contrasts='labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine')
        >>> milo.build_nhood_graph(mdata)
        >>> milo.plot_nhood_graph(mdata)

    Preview:
        .. image:: /_static/docstring_previews/milo_nhood_graph.png
    """
    # Transpose so that neighbourhoods become observations (rows) for scanpy plotting.
    nhood_adata = mdata["milo"].T.copy()

    if "Nhood_size" not in nhood_adata.obs.columns:
        raise KeyError(
            'Cannot find "Nhood_size" column in mdata["milo"].var -- please run milo.build_nhood_graph(mdata) first'
        )

    # Hide (NaN) neighbourhoods that are not significant or whose effect is too small.
    nhood_adata.obs["graph_color"] = nhood_adata.obs["logFC"]
    nhood_adata.obs.loc[nhood_adata.obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
    nhood_adata.obs["abs_logFC"] = abs(nhood_adata.obs["logFC"])
    nhood_adata.obs.loc[nhood_adata.obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan

    # Plotting order - extreme logFC on top (hidden nodes are drawn first).
    nhood_adata.obs.loc[nhood_adata.obs["graph_color"].isna(), "abs_logFC"] = np.nan
    ordered = nhood_adata.obs.sort_values("abs_logFC", na_position="first").index
    nhood_adata = nhood_adata[ordered]

    # Symmetric color limits so that logFC == 0 maps to the center of the diverging colormap.
    vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
    vmin = -vmax

    # Forward a single colormap: scanpy's `embedding` refuses `cmap` and `color_map` together.
    cmap = "RdBu_r" if color_map is None else color_map

    return sc.pl.embedding(
        nhood_adata,
        "X_milo_graph",
        color="graph_color",
        cmap=cmap,
        size=nhood_adata.obs["Nhood_size"] * min_size,
        edges=plot_edges,
        neighbors_key="nhood",
        sort_order=False,
        frameon=False,
        vmax=vmax,
        vmin=vmin,
        title=title,
        palette=palette,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
790
+
791
def plot_nhood(
    self,
    mdata: MuData,
    ix: int,
    feature_key: str | None = "rna",
    basis: str = "X_umap",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> plt.Figure | Axes | list[Axes] | None:
    """Visualize cells in a neighbourhood.

    Cells are colored by their membership value in neighbourhood `ix`. As a side effect, that
    membership vector is stored in `mdata[feature_key].obs["Nhood"]`, overwriting any existing
    column of that name.

    Args:
        mdata: MuData object with feature_key slot, storing neighbourhood assignments in
            `mdata[feature_key].obsm['nhoods']`.
        ix: Index of the neighbourhood (column of `obsm['nhoods']`) to visualize.
        feature_key: Key of the cell-level AnnData object in `mdata`. Defaults to "rna".
        basis: Embedding to use for visualization. Defaults to "X_umap".
        color_map: Colormap passed to `scanpy.pl.embedding`.
        palette: Palette passed to `scanpy.pl.embedding`.
        return_fig: Passed to `scanpy.pl.embedding`; if True the figure is returned.
        ax: Matplotlib axes to draw on.
        show: Show the plot, do not return axis.
        save: If True or a str, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
        **kwargs: Additional arguments to `scanpy.pl.embedding`.

    Returns:
        The return value of `scanpy.pl.embedding` (figure, axes or None, depending on
        `return_fig` and `show`).

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> sc.tl.umap(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> milo.plot_nhood(mdata, ix=0)

    Preview:
        .. image:: /_static/docstring_previews/milo_nhood.png
    """
    cell_adata = mdata[feature_key]
    membership = cell_adata.obsm["nhoods"][:, ix]
    # `obsm['nhoods']` may be a sparse matrix (has `.toarray()`) or a dense array.
    if hasattr(membership, "toarray"):
        membership = membership.toarray()
    cell_adata.obs["Nhood"] = np.asarray(membership).ravel()

    return sc.pl.embedding(
        cell_adata,
        basis,
        color="Nhood",
        size=30,
        title=f"Nhood{ix}",
        color_map=color_map,
        palette=palette,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
844
+
845
def plot_da_beeswarm(
    self,
    mdata: MuData,
    feature_key: str | None = "rna",
    anno_col: str = "nhood_annotation",
    alpha: float = 0.1,
    subset_nhoods: list[str] | None = None,
    palette: str | Sequence[str] | dict[str, str] | None = None,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    show: bool | None = None,
) -> plt.Figure | Axes | None:
    """Plot beeswarm plot of logFC against nhood labels.

    One horizontal violin (plus a strip of points) is drawn per annotation label. Labels are
    ordered by median logFC. Points are colored by whether `SpatialFDR < alpha`. Neighbourhoods
    labelled with the string "nan" are excluded.

    Args:
        mdata: MuData object with a 'milo' modality holding `logFC`, `SpatialFDR` and `anno_col`
            in `mdata['milo'].var`.
        feature_key: Key of the cell-level AnnData object in `mdata`, used to look up default
            category colors. Defaults to "rna".
        anno_col: Column in `mdata['milo'].var` to use as annotation. Defaults to 'nhood_annotation'.
        alpha: Significance threshold on `SpatialFDR`. Defaults to 0.1.
        subset_nhoods: List of annotation labels to plot. If None, plot all nhoods. Defaults to None.
        palette: Name of Seaborn color palette for violinplots. If None, the category colors stored
            in `mdata[feature_key].uns[f"{obs_col}_colors"]` are used when available and when they
            cover every plotted label; otherwise seaborn's default colors are used.
        return_fig: If True, return the current figure.
        save: If truthy, passed as the filename to `plt.savefig`.
        show: If True, call `plt.show()`.

    Returns:
        The current figure if `return_fig` is set. Otherwise the current axes if neither `show`
        nor `save` is set, else None.

    Raises:
        RuntimeError: If the 'milo' modality, `anno_col` or `logFC` is missing.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> adata = pt.dt.bhattacherjee()
        >>> milo = pt.tl.Milo()
        >>> mdata = milo.load(adata)
        >>> sc.pp.neighbors(mdata["rna"])
        >>> milo.make_nhoods(mdata["rna"])
        >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
        >>> milo.da_nhoods(mdata, design="~label")
        >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
        >>> milo.plot_da_beeswarm(mdata)

    Preview:
        .. image:: /_static/docstring_previews/milo_da_beeswarm.png
    """
    try:
        nhood_adata = mdata["milo"].T.copy()
    except KeyError:
        raise RuntimeError(
            "mdata should be a MuData object with two slots: feature_key and 'milo'. "
            "Run 'milo.count_nhoods(mdata, sample_col)' first."
        ) from None

    if anno_col not in nhood_adata.obs.columns:
        raise RuntimeError(
            f"Unable to find {anno_col} in mdata['milo'].var. Run 'milo.annotate_nhoods(mdata, anno_col)' first"
        )

    if subset_nhoods is not None:
        nhood_adata = nhood_adata[nhood_adata.obs[anno_col].isin(subset_nhoods)]

    if "logFC" not in nhood_adata.obs.columns:
        raise RuntimeError("Unable to find 'logFC' in mdata['milo'].var. Run 'milo.da_nhoods(mdata)' first.")

    anno_df = nhood_adata.obs[[anno_col, "logFC", "SpatialFDR"]].copy()
    anno_df["is_signif"] = anno_df["SpatialFDR"] < alpha
    anno_df = anno_df[anno_df[anno_col] != "nan"]

    # Order labels by median logFC. This is computed after dropping "nan" labels and over observed
    # categories only, so that no empty row is reserved for labels that are not plotted.
    sorted_annos = (
        anno_df[[anno_col, "logFC"]]
        .groupby(anno_col, observed=True)
        .median()
        .sort_values("logFC", ascending=True)
        .index
    )

    if palette is None:
        # Best effort: reuse the category colors stored for the annotation column, if any.
        try:
            obs_col = nhood_adata.uns["annotation_obs"]
            palette = dict(
                zip(
                    mdata[feature_key].obs[obs_col].cat.categories,
                    mdata[feature_key].uns[f"{obs_col}_colors"],
                    strict=False,
                )
            )
        except (KeyError, AttributeError):
            # Missing uns entries, or a non-categorical column (no `.cat`): use seaborn defaults.
            palette = None
        # A dict palette that misses a plotted label cannot be used; fall back to defaults.
        if palette is not None and not set(sorted_annos).issubset(palette):
            palette = None

    sns.violinplot(
        data=anno_df,
        y=anno_col,
        x="logFC",
        order=sorted_annos,
        inner=None,
        orient="h",
        palette=palette,
        linewidth=0,
        scale="width",
    )
    sns.stripplot(
        data=anno_df,
        y=anno_col,
        x="logFC",
        order=sorted_annos,
        size=2,
        hue="is_signif",
        palette=["grey", "black"],
        orient="h",
        alpha=0.5,
    )
    plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
    plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return plt.gca()
    return None
969
+
970
def plot_nhood_counts_by_cond(
    self,
    mdata: MuData,
    test_var: str,
    subset_nhoods: list[str] | None = None,
    log_counts: bool = False,
    return_fig: bool | None = None,
    save: bool | str | None = None,
    show: bool | None = None,
) -> plt.Figure | Axes | None:
    """Plot boxplot of cell numbers vs condition of interest.

    For every selected neighbourhood and sample, the number of cells is taken from
    `mdata['milo']`. The counts are grouped by the sample-level column `test_var`.

    Args:
        mdata: MuData object storing cell level and nhood level information.
        test_var: Name of the sample-level column (in `mdata['milo'].obs`) storing the condition
            of interest (x-axis of the boxplot).
        subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all
            nhoods. Defaults to None.
        log_counts: Whether to plot log1p of cell counts. Defaults to False.
        return_fig: If True, return the current figure.
        save: If truthy, passed as the filename to `plt.savefig`.
        show: If True, call `plt.show()`.

    Returns:
        The current figure if `return_fig` is set. Otherwise the current axes if neither `show`
        nor `save` is set, else None.

    Raises:
        RuntimeError: If `mdata` has no 'milo' modality.
    """
    try:
        nhood_adata = mdata["milo"].T.copy()
    except KeyError:
        raise RuntimeError(
            "mdata should be a MuData object with two slots: feature_key and 'milo'. "
            "Run milo.count_nhoods(mdata, sample_col) first"
        ) from None

    if subset_nhoods is None:
        subset_nhoods = nhood_adata.obs_names

    counts = nhood_adata[subset_nhoods].X
    # Sparse `.A` is deprecated (SciPy 1.11) and removed in later releases; `.toarray()` handles
    # sparse input and `np.asarray` covers a dense X.
    counts = counts.toarray() if hasattr(counts, "toarray") else np.asarray(counts)

    # Long format: one row per (neighbourhood, sample) with the number of cells.
    pl_df = pd.DataFrame(counts, columns=nhood_adata.var_names).melt(
        var_name=nhood_adata.uns["sample_col"], value_name="n_cells"
    )
    # Merges on the columns shared with the sample metadata (expected: the sample column), which
    # brings `test_var` into the frame. NOTE(review): requires `.var` to contain that column.
    pl_df = pd.merge(pl_df, nhood_adata.var)
    pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])

    y_col, y_label = ("log_n_cells", "log(# cells + 1)") if log_counts else ("n_cells", "# cells")
    sns.boxplot(data=pl_df, x=test_var, y=y_col, color="lightblue")
    sns.stripplot(data=pl_df, x=test_var, y=y_col, color="black", s=3)
    plt.ylabel(y_label)

    plt.xticks(rotation=90)
    plt.xlabel(test_var)

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return plt.gca()
    return None