pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pertpy/tools/_milo.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import random
5
4
  import re
6
5
  from typing import TYPE_CHECKING, Literal
@@ -14,6 +13,8 @@ from anndata import AnnData
14
13
  from lamin_utils import logger
15
14
  from mudata import MuData
16
15
 
16
+ from pertpy._doc import _doc_params, doc_common_plot_args
17
+
17
18
  if TYPE_CHECKING:
18
19
  from collections.abc import Sequence
19
20
 
@@ -125,7 +126,7 @@ class Milo:
125
126
  try:
126
127
  use_rep = adata.uns["neighbors"]["params"]["use_rep"]
127
128
  except KeyError:
128
- logging.warning("Using X_pca as default embedding")
129
+ logger.warning("Using X_pca as default embedding")
129
130
  use_rep = "X_pca"
130
131
  try:
131
132
  knn_graph = adata.obsp["connectivities"].copy()
@@ -136,7 +137,7 @@ class Milo:
136
137
  try:
137
138
  use_rep = adata.uns[neighbors_key]["params"]["use_rep"]
138
139
  except KeyError:
139
- logging.warning("Using X_pca as default embedding")
140
+ logger.warning("Using X_pca as default embedding")
140
141
  use_rep = "X_pca"
141
142
  knn_graph = adata.obsp[neighbors_key + "_connectivities"].copy()
142
143
 
@@ -182,7 +183,7 @@ class Milo:
182
183
  knn_dists = adata.obsp[neighbors_key + "_distances"]
183
184
 
184
185
  nhood_ixs = adata.obs["nhood_ixs_refined"] == 1
185
- dist_mat = knn_dists[nhood_ixs, :]
186
+ dist_mat = knn_dists[np.asarray(nhood_ixs), :]
186
187
  k_distances = dist_mat.max(1).toarray().ravel()
187
188
  adata.obs["nhood_kth_distance"] = 0
188
189
  adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
@@ -703,8 +704,8 @@ class Milo:
703
704
  pvalues = sample_adata.var["PValue"]
704
705
  keep_nhoods = ~pvalues.isna() # Filtering in case of test on subset of nhoods
705
706
  o = pvalues[keep_nhoods].argsort()
706
- pvalues = pvalues[keep_nhoods][o]
707
- w = w[keep_nhoods][o]
707
+ pvalues = pvalues.loc[keep_nhoods].iloc[o]
708
+ w = w.loc[keep_nhoods].iloc[o]
708
709
 
709
710
  adjp = np.zeros(shape=len(o))
710
711
  adjp[o] = (sum(w) * pvalues / np.cumsum(w))[::-1].cummin()[::-1]
@@ -713,9 +714,11 @@ class Milo:
713
714
  sample_adata.var["SpatialFDR"] = np.nan
714
715
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
715
716
 
717
+ @_doc_params(common_plot_args=doc_common_plot_args)
716
718
  def plot_nhood_graph(
717
719
  self,
718
720
  mdata: MuData,
721
+ *,
719
722
  alpha: float = 0.1,
720
723
  min_logFC: float = 0,
721
724
  min_size: int = 10,
@@ -724,10 +727,9 @@ class Milo:
724
727
  color_map: Colormap | str | None = None,
725
728
  palette: str | Sequence[str] | None = None,
726
729
  ax: Axes | None = None,
727
- show: bool | None = None,
728
- save: bool | str | None = None,
730
+ return_fig: bool = False,
729
731
  **kwargs,
730
- ) -> None:
732
+ ) -> Figure | None:
731
733
  """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding)
732
734
 
733
735
  Args:
@@ -737,9 +739,7 @@ class Milo:
737
739
  min_size: Minimum size of nodes in visualization. (default: 10)
738
740
  plot_edges: If edges for neighbourhood overlaps whould be plotted.
739
741
  title: Plot title.
740
- show: Show the plot, do not return axis.
741
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
742
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
742
+ {common_plot_args}
743
743
  **kwargs: Additional arguments to `scanpy.pl.embedding`.
744
744
 
745
745
  Examples:
@@ -782,7 +782,7 @@ class Milo:
782
782
  vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
783
783
  vmin = -vmax
784
784
 
785
- sc.pl.embedding(
785
+ fig = sc.pl.embedding(
786
786
  nhood_adata,
787
787
  "X_milo_graph",
788
788
  color="graph_color",
@@ -798,33 +798,40 @@ class Milo:
798
798
  color_map=color_map,
799
799
  palette=palette,
800
800
  ax=ax,
801
- show=show,
802
- save=save,
801
+ show=False,
803
802
  **kwargs,
804
803
  )
805
804
 
805
+ if return_fig:
806
+ return fig
807
+ plt.show()
808
+ return None
809
+
810
+ @_doc_params(common_plot_args=doc_common_plot_args)
806
811
  def plot_nhood(
807
812
  self,
808
813
  mdata: MuData,
809
814
  ix: int,
815
+ *,
810
816
  feature_key: str | None = "rna",
811
817
  basis: str = "X_umap",
812
818
  color_map: Colormap | str | None = None,
813
819
  palette: str | Sequence[str] | None = None,
814
- return_fig: bool | None = None,
815
820
  ax: Axes | None = None,
816
- show: bool | None = None,
817
- save: bool | str | None = None,
821
+ return_fig: bool = False,
818
822
  **kwargs,
819
- ) -> None:
823
+ ) -> Figure | None:
820
824
  """Visualize cells in a neighbourhood.
821
825
 
822
826
  Args:
823
827
  mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
824
828
  ix: index of neighbourhood to visualize
829
+ feature_key: Key in mdata to the cell-level AnnData object.
825
830
  basis: Embedding to use for visualization.
826
- show: Show the plot, do not return axis.
827
- save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
831
+ color_map: Colormap to use for coloring.
832
+ palette: Color palette to use for coloring.
833
+ ax: Axes to plot on.
834
+ {common_plot_args}
828
835
  **kwargs: Additional arguments to `scanpy.pl.embedding`.
829
836
 
830
837
  Examples:
@@ -842,7 +849,7 @@ class Milo:
842
849
  .. image:: /_static/docstring_previews/milo_nhood.png
843
850
  """
844
851
  mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
845
- sc.pl.embedding(
852
+ fig = sc.pl.embedding(
846
853
  mdata[feature_key],
847
854
  basis,
848
855
  color="Nhood",
@@ -852,32 +859,41 @@ class Milo:
852
859
  palette=palette,
853
860
  return_fig=return_fig,
854
861
  ax=ax,
855
- show=show,
856
- save=save,
862
+ show=False,
857
863
  **kwargs,
858
864
  )
859
865
 
866
+ if return_fig:
867
+ return fig
868
+ plt.show()
869
+ return None
870
+
871
+ @_doc_params(common_plot_args=doc_common_plot_args)
860
872
  def plot_da_beeswarm(
861
873
  self,
862
874
  mdata: MuData,
875
+ *,
863
876
  feature_key: str | None = "rna",
864
877
  anno_col: str = "nhood_annotation",
865
878
  alpha: float = 0.1,
866
879
  subset_nhoods: list[str] = None,
867
880
  palette: str | Sequence[str] | dict[str, str] | None = None,
868
- return_fig: bool | None = None,
869
- save: bool | str | None = None,
870
- show: bool | None = None,
871
- ) -> Figure | Axes | None:
881
+ return_fig: bool = False,
882
+ ) -> Figure | None:
872
883
  """Plot beeswarm plot of logFC against nhood labels
873
884
 
874
885
  Args:
875
886
  mdata: MuData object
887
+ feature_key: Key in mdata to the cell-level AnnData object.
876
888
  anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
877
889
  alpha: Significance threshold. (default: 0.1)
878
890
  subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
879
891
  palette: Name of Seaborn color palette for violinplots.
880
892
  Defaults to pre-defined category colors for violinplots.
893
+ {common_plot_args}
894
+
895
+ Returns:
896
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
881
897
 
882
898
  Examples:
883
899
  >>> import pertpy as pt
@@ -973,29 +989,21 @@ class Milo:
973
989
  plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
974
990
  plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
975
991
 
976
- if save:
977
- plt.savefig(save, bbox_inches="tight")
978
- return None
979
- if show:
980
- plt.show()
981
- return None
982
992
  if return_fig:
983
993
  return plt.gcf()
984
- if (not show and not save) or (show is None and save is None):
985
- return plt.gca()
986
-
994
+ plt.show()
987
995
  return None
988
996
 
997
+ @_doc_params(common_plot_args=doc_common_plot_args)
989
998
  def plot_nhood_counts_by_cond(
990
999
  self,
991
1000
  mdata: MuData,
992
1001
  test_var: str,
1002
+ *,
993
1003
  subset_nhoods: list[str] = None,
994
1004
  log_counts: bool = False,
995
- return_fig: bool | None = None,
996
- save: bool | str | None = None,
997
- show: bool | None = None,
998
- ) -> Figure | Axes | None:
1005
+ return_fig: bool = False,
1006
+ ) -> Figure | None:
999
1007
  """Plot boxplot of cell numbers vs condition of interest.
1000
1008
 
1001
1009
  Args:
@@ -1003,6 +1011,10 @@ class Milo:
1003
1011
  test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
1004
1012
  subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
1005
1013
  log_counts: Whether to plot log1p of cell counts.
1014
+ {common_plot_args}
1015
+
1016
+ Returns:
1017
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1006
1018
  """
1007
1019
  try:
1008
1020
  nhood_adata = mdata["milo"].T.copy()
@@ -1031,15 +1043,7 @@ class Milo:
1031
1043
  plt.xticks(rotation=90)
1032
1044
  plt.xlabel(test_var)
1033
1045
 
1034
- if save:
1035
- plt.savefig(save, bbox_inches="tight")
1036
- return None
1037
- if show:
1038
- plt.show()
1039
- return None
1040
1046
  if return_fig:
1041
1047
  return plt.gcf()
1042
- if not (show or save):
1043
- return plt.gca()
1044
-
1048
+ plt.show()
1045
1049
  return None