pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -2,7 +2,7 @@ from collections.abc import Sequence
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
5
- from scanpy import logging
5
+ from lamin_utils import logger
6
6
  from scipy.sparse import issparse
7
7
 
8
8
  from ._base import LinearModelBase
@@ -27,16 +27,11 @@ class EdgeR(LinearModelBase):
27
27
  # pandas2ri.activate()
28
28
  # rpy2.robjects.numpy2ri.activate()
29
29
  try:
30
- import rpy2.robjects.numpy2ri
31
- import rpy2.robjects.pandas2ri
32
30
  from rpy2 import robjects as ro
33
31
  from rpy2.robjects import numpy2ri, pandas2ri
34
- from rpy2.robjects.conversion import localconverter
32
+ from rpy2.robjects.conversion import get_conversion, localconverter
35
33
  from rpy2.robjects.packages import importr
36
34
 
37
- pandas2ri.activate()
38
- rpy2.robjects.numpy2ri.activate()
39
-
40
35
  except ImportError:
41
36
  raise ImportError("edger requires rpy2 to be installed.") from None
42
37
 
@@ -49,25 +44,30 @@ class EdgeR(LinearModelBase):
49
44
  ) from e
50
45
 
51
46
  # Convert dataframe
52
- with localconverter(ro.default_converter + numpy2ri.converter):
47
+ with localconverter(get_conversion() + numpy2ri.converter):
53
48
  expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
54
49
  if issparse(expr):
55
50
  expr = expr.T.toarray()
56
51
  else:
57
52
  expr = expr.T
58
53
 
59
- expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
54
+ with localconverter(get_conversion() + pandas2ri.converter):
55
+ expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
56
+ samples_r = ro.conversion.py2rpy(self.adata.obs)
60
57
 
61
- dge = edger.DGEList(counts=expr_r, samples=self.adata.obs)
58
+ dge = edger.DGEList(counts=expr_r, samples=samples_r)
62
59
 
63
- logging.info("Calculating NormFactors")
60
+ logger.info("Calculating NormFactors")
64
61
  dge = edger.calcNormFactors(dge)
65
62
 
66
- logging.info("Estimating Dispersions")
67
- dge = edger.estimateDisp(dge, design=self.design)
63
+ with localconverter(get_conversion() + numpy2ri.converter):
64
+ design_r = ro.conversion.py2rpy(self.design.values)
65
+
66
+ logger.info("Estimating Dispersions")
67
+ dge = edger.estimateDisp(dge, design=design_r)
68
68
 
69
- logging.info("Fitting linear model")
70
- fit = edger.glmQLFit(dge, design=self.design, **kwargs)
69
+ logger.info("Fitting linear model")
70
+ fit = edger.glmQLFit(dge, design=design_r, **kwargs)
71
71
 
72
72
  ro.globalenv["fit"] = fit
73
73
  self.fit = fit
@@ -88,11 +88,9 @@ class EdgeR(LinearModelBase):
88
88
  # Fix mask for .fit()
89
89
 
90
90
  try:
91
- import rpy2.robjects.numpy2ri
92
- import rpy2.robjects.pandas2ri
93
91
  from rpy2 import robjects as ro
94
92
  from rpy2.robjects import numpy2ri, pandas2ri
95
- from rpy2.robjects.conversion import localconverter
93
+ from rpy2.robjects.conversion import get_conversion, localconverter
96
94
  from rpy2.robjects.packages import importr
97
95
 
98
96
  except ImportError:
@@ -106,7 +104,8 @@ class EdgeR(LinearModelBase):
106
104
  ) from None
107
105
 
108
106
  # Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
109
- contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
107
+ with localconverter(get_conversion() + numpy2ri.converter):
108
+ contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
110
109
  ro.globalenv["contrast_vec"] = contrast_vec_r
111
110
 
112
111
  # Test contrast with R
@@ -117,8 +116,18 @@ class EdgeR(LinearModelBase):
117
116
  """
118
117
  )
119
118
 
120
- # Convert results to pandas
121
- de_res = ro.conversion.rpy2py(ro.globalenv["de_res"])
119
+ # Retrieve the `de_res` object
120
+ de_res = ro.globalenv["de_res"]
121
+
122
+ # If already a Pandas DataFrame, return it directly
123
+ if isinstance(de_res, pd.DataFrame):
124
+ de_res.index.name = "variable"
125
+ return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
126
+
127
+ # Convert to Pandas DataFrame if still an R object
128
+ with localconverter(get_conversion() + pandas2ri.converter):
129
+ de_res = ro.conversion.rpy2py(de_res)
130
+
122
131
  de_res.index.name = "variable"
123
132
  de_res = de_res.reset_index()
124
133
 
@@ -2,6 +2,7 @@ import os
2
2
  import re
3
3
  import warnings
4
4
 
5
+ import numpy as np
5
6
  import pandas as pd
6
7
  from anndata import AnnData
7
8
  from numpy import ndarray
@@ -40,33 +41,25 @@ class PyDESeq2(LinearModelBase):
40
41
  Args:
41
42
  **kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
42
43
  """
43
- inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", os.cpu_count() - 1))
44
- covars = self.design.columns.tolist()
45
- if "Intercept" not in covars:
46
- warnings.warn(
47
- "Warning: Pydeseq is hard-coded to use Intercept, please include intercept into the model", stacklevel=2
48
- )
49
- processed_covars = list({re.sub(r"\[T\.(.*)\]", "", col) for col in covars if col != "Intercept"})
44
+ try:
45
+ usable_cpus = len(os.sched_getaffinity(0))
46
+ except AttributeError:
47
+ usable_cpus = os.cpu_count()
48
+
49
+ inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
50
+
50
51
  dds = DeseqDataSet(
51
- adata=self.adata, design_factors=processed_covars, refit_cooks=True, inference=inference, **kwargs
52
+ adata=self.adata,
53
+ design=self.design, # initialize using design matrix, not formula
54
+ refit_cooks=True,
55
+ inference=inference,
56
+ **kwargs,
52
57
  )
53
- # workaround code to insert design array
54
- des_mtx_cols = dds.obsm["design_matrix"].columns
55
- dds.obsm["design_matrix"] = self.design
56
- if dds.obsm["design_matrix"].shape[1] == len(des_mtx_cols):
57
- dds.obsm["design_matrix"].columns = des_mtx_cols.copy()
58
58
 
59
59
  dds.deseq2()
60
60
  self.dds = dds
61
61
 
62
- # TODO: PyDeseq2 doesn't support arbitrary designs and contrasts yet
63
- # see https://github.com/owkin/PyDESeq2/issues/213
64
-
65
- # Therefore these functions are overridden in a way to make it work with PyDESeq2,
66
- # ingoring the inconsistency of function signatures. Once arbitrary design
67
- # matrices and contrasts are supported by PyDEseq2, we can fully support the
68
- # Linear model interface.
69
- def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame: # type: ignore
62
+ def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
70
63
  """Conduct a specific test and returns a Pandas DataFrame.
71
64
 
72
65
  Args:
@@ -74,6 +67,7 @@ class PyDESeq2(LinearModelBase):
74
67
  alpha: p value threshold used for controlling fdr with independent hypothesis weighting
75
68
  **kwargs: extra arguments to pass to DeseqStats()
76
69
  """
70
+ contrast = np.array(contrast)
77
71
  stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
78
72
  # Calling `.summary()` is required to fill the `results_df` data frame
79
73
  stat_res.summary()
@@ -85,11 +79,3 @@ class PyDESeq2(LinearModelBase):
85
79
  res_df.index.name = "variable"
86
80
  res_df = res_df.reset_index()
87
81
  return res_df
88
-
89
- def cond(self, **kwargs) -> ndarray:
90
- raise NotImplementedError(
91
- "PyDESeq2 currently doesn't support arbitrary contrasts, see https://github.com/owkin/PyDESeq2/issues/213"
92
- )
93
-
94
- def contrast(self, column: str, baseline: str, group_to_compare: str) -> tuple[str, str, str]: # type: ignore
95
- return (column, group_to_compare, baseline)
@@ -59,14 +59,3 @@ class Statsmodels(LinearModelBase):
59
59
  }
60
60
  )
61
61
  return pd.DataFrame(res).sort_values("p_value")
62
-
63
- def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndarray:
64
- """Build a simple contrast for pairwise comparisons.
65
-
66
- This is equivalent to
67
-
68
- ```
69
- model.cond(<column> = baseline) - model.cond(<column> = group_to_compare)
70
- ```
71
- """
72
- return self.cond(**{column: baseline}) - self.cond(**{column: group_to_compare})
@@ -344,9 +344,9 @@ class Distance:
344
344
  else:
345
345
  embedding = adata.obsm[self.obsm_key].copy()
346
346
  for index_x, group_x in enumerate(fct(groups)):
347
- cells_x = embedding[grouping == group_x].copy()
347
+ cells_x = embedding[np.asarray(grouping == group_x)].copy()
348
348
  for group_y in groups[index_x:]: # type: ignore
349
- cells_y = embedding[grouping == group_y].copy()
349
+ cells_y = embedding[np.asarray(grouping == group_y)].copy()
350
350
  if not bootstrap:
351
351
  # By distance axiom, the distance between a group and itself is 0
352
352
  dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
@@ -478,9 +478,9 @@ class Distance:
478
478
  else:
479
479
  embedding = adata.obsm[self.obsm_key].copy()
480
480
  for group_x in fct(groups):
481
- cells_x = embedding[grouping == group_x].copy()
481
+ cells_x = embedding[np.asarray(grouping == group_x)].copy()
482
482
  group_y = selected_group
483
- cells_y = embedding[grouping == group_y].copy()
483
+ cells_y = embedding[np.asarray(grouping == group_y)].copy()
484
484
  if not bootstrap:
485
485
  # By distance axiom, the distance between a group and itself is 0
486
486
  dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
@@ -691,17 +691,18 @@ class MMD(AbstractDistance):
691
691
 
692
692
 
693
693
  class WassersteinDistance(AbstractDistance):
694
- """Wasserstein distance metric (solved with entropy regularized Sinkhorn)."""
695
-
696
694
  def __init__(self) -> None:
697
695
  super().__init__()
698
696
  self.accepts_precomputed = False
699
697
 
700
698
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
699
+ X = np.asarray(X, dtype=np.float64)
700
+ Y = np.asarray(Y, dtype=np.float64)
701
701
  geom = PointCloud(X, Y)
702
702
  return self.solve_ot_problem(geom, **kwargs)
703
703
 
704
704
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
705
+ P = np.asarray(P, dtype=np.float64)
705
706
  geom = Geometry(cost_matrix=P[idx, :][:, ~idx])
706
707
  return self.solve_ot_problem(geom, **kwargs)
707
708
 
@@ -709,7 +710,13 @@ class WassersteinDistance(AbstractDistance):
709
710
  ot_prob = LinearProblem(geom)
710
711
  solver = Sinkhorn()
711
712
  ot = solver(ot_prob, **kwargs)
712
- return ot.reg_ot_cost.item()
713
+ cost = float(ot.reg_ot_cost)
714
+
715
+ # Check for NaN or invalid cost
716
+ if not np.isfinite(cost):
717
+ return 1.0
718
+ else:
719
+ return cost
713
720
 
714
721
 
715
722
  class EuclideanDistance(AbstractDistance):
@@ -981,7 +988,7 @@ class NBLL(AbstractDistance):
981
988
  try:
982
989
  nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
983
990
  return _compute_nll(y, nb_params, epsilon)
984
- except np.linalg.linalg.LinAlgError:
991
+ except np.linalg.LinAlgError:
985
992
  if x.mean() < 10 and y.mean() < 10:
986
993
  return 0.0
987
994
  else:
@@ -3,6 +3,7 @@ from collections.abc import Sequence
3
3
  from typing import Any, Literal
4
4
 
5
5
  import blitzgsea
6
+ import matplotlib.pyplot as plt
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  import scanpy as sc
@@ -14,6 +15,7 @@ from scipy.sparse import issparse
14
15
  from scipy.stats import hypergeom
15
16
  from statsmodels.stats.multitest import multipletests
16
17
 
18
+ from pertpy._doc import _doc_params, doc_common_plot_args
17
19
  from pertpy.metadata import Drug
18
20
 
19
21
 
@@ -290,9 +292,11 @@ class Enrichment:
290
292
 
291
293
  return enrichment
292
294
 
295
+ @_doc_params(common_plot_args=doc_common_plot_args)
293
296
  def plot_dotplot(
294
297
  self,
295
298
  adata: AnnData,
299
+ *,
296
300
  targets: dict[str, dict[str, list[str]]] = None,
297
301
  source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
298
302
  category_name: str = "interaction_type",
@@ -300,10 +304,10 @@ class Enrichment:
300
304
  groupby: str = None,
301
305
  key: str = "pertpy_enrichment",
302
306
  ax: Axes | None = None,
303
- save: bool | str | None = None,
304
- show: bool | None = None,
307
+ show: bool = True,
308
+ return_fig: bool = False,
305
309
  **kwargs,
306
- ) -> DotPlot | dict | None:
310
+ ) -> DotPlot | None:
307
311
  """Plots a dotplot by groupby and categories.
308
312
 
309
313
  Wraps scanpy's dotplot but formats it nicely by categories.
@@ -319,11 +323,11 @@ class Enrichment:
319
323
  category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
320
324
  groupby: dotplot groupby such as clusters or cell types.
321
325
  key: Prefix key of enrichment results in `uns`.
326
+ {common_plot_args}
322
327
  kwargs: Passed to scanpy dotplot.
323
328
 
324
329
  Returns:
325
- If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object,
326
- else if `show` is false, return axes dict.
330
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
327
331
 
328
332
  Examples:
329
333
  >>> import pertpy as pt
@@ -403,21 +407,27 @@ class Enrichment:
403
407
  "var_group_labels": var_group_labels,
404
408
  }
405
409
 
406
- return sc.pl.dotplot(
410
+ fig = sc.pl.dotplot(
407
411
  enrichment_score_adata,
408
412
  groupby=groupby,
409
413
  swap_axes=True,
410
414
  ax=ax,
411
- save=save,
412
- show=show,
415
+ show=False,
413
416
  **plot_args,
414
417
  **kwargs,
415
418
  )
416
419
 
420
+ if show:
421
+ plt.show()
422
+ if return_fig:
423
+ return fig
424
+ return None
425
+
417
426
  def plot_gsea(
418
427
  self,
419
428
  adata: AnnData,
420
429
  enrichment: dict[str, pd.DataFrame],
430
+ *,
421
431
  n: int = 10,
422
432
  key: str = "pertpy_enrichment_gsea",
423
433
  interactive_plot: bool = False,
pertpy/tools/_milo.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import random
5
4
  import re
6
5
  from typing import TYPE_CHECKING, Literal
@@ -14,6 +13,8 @@ from anndata import AnnData
14
13
  from lamin_utils import logger
15
14
  from mudata import MuData
16
15
 
16
+ from pertpy._doc import _doc_params, doc_common_plot_args
17
+
17
18
  if TYPE_CHECKING:
18
19
  from collections.abc import Sequence
19
20
 
@@ -125,7 +126,7 @@ class Milo:
125
126
  try:
126
127
  use_rep = adata.uns["neighbors"]["params"]["use_rep"]
127
128
  except KeyError:
128
- logging.warning("Using X_pca as default embedding")
129
+ logger.warning("Using X_pca as default embedding")
129
130
  use_rep = "X_pca"
130
131
  try:
131
132
  knn_graph = adata.obsp["connectivities"].copy()
@@ -136,7 +137,7 @@ class Milo:
136
137
  try:
137
138
  use_rep = adata.uns[neighbors_key]["params"]["use_rep"]
138
139
  except KeyError:
139
- logging.warning("Using X_pca as default embedding")
140
+ logger.warning("Using X_pca as default embedding")
140
141
  use_rep = "X_pca"
141
142
  knn_graph = adata.obsp[neighbors_key + "_connectivities"].copy()
142
143
 
@@ -182,7 +183,7 @@ class Milo:
182
183
  knn_dists = adata.obsp[neighbors_key + "_distances"]
183
184
 
184
185
  nhood_ixs = adata.obs["nhood_ixs_refined"] == 1
185
- dist_mat = knn_dists[nhood_ixs, :]
186
+ dist_mat = knn_dists[np.asarray(nhood_ixs), :]
186
187
  k_distances = dist_mat.max(1).toarray().ravel()
187
188
  adata.obs["nhood_kth_distance"] = 0
188
189
  adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
@@ -703,8 +704,8 @@ class Milo:
703
704
  pvalues = sample_adata.var["PValue"]
704
705
  keep_nhoods = ~pvalues.isna() # Filtering in case of test on subset of nhoods
705
706
  o = pvalues[keep_nhoods].argsort()
706
- pvalues = pvalues[keep_nhoods][o]
707
- w = w[keep_nhoods][o]
707
+ pvalues = pvalues.loc[keep_nhoods].iloc[o]
708
+ w = w.loc[keep_nhoods].iloc[o]
708
709
 
709
710
  adjp = np.zeros(shape=len(o))
710
711
  adjp[o] = (sum(w) * pvalues / np.cumsum(w))[::-1].cummin()[::-1]
@@ -713,9 +714,11 @@ class Milo:
713
714
  sample_adata.var["SpatialFDR"] = np.nan
714
715
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
715
716
 
717
+ @_doc_params(common_plot_args=doc_common_plot_args)
716
718
  def plot_nhood_graph(
717
719
  self,
718
720
  mdata: MuData,
721
+ *,
719
722
  alpha: float = 0.1,
720
723
  min_logFC: float = 0,
721
724
  min_size: int = 10,
@@ -724,10 +727,10 @@ class Milo:
724
727
  color_map: Colormap | str | None = None,
725
728
  palette: str | Sequence[str] | None = None,
726
729
  ax: Axes | None = None,
727
- show: bool | None = None,
728
- save: bool | str | None = None,
730
+ show: bool = True,
731
+ return_fig: bool = False,
729
732
  **kwargs,
730
- ) -> None:
733
+ ) -> Figure | None:
731
734
  """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding)
732
735
 
733
736
  Args:
@@ -737,9 +740,7 @@ class Milo:
737
740
  min_size: Minimum size of nodes in visualization. (default: 10)
738
741
  plot_edges: If edges for neighbourhood overlaps whould be plotted.
739
742
  title: Plot title.
740
- show: Show the plot, do not return axis.
741
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
742
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
743
+ {common_plot_args}
743
744
  **kwargs: Additional arguments to `scanpy.pl.embedding`.
744
745
 
745
746
  Examples:
@@ -782,7 +783,7 @@ class Milo:
782
783
  vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
783
784
  vmin = -vmax
784
785
 
785
- sc.pl.embedding(
786
+ fig = sc.pl.embedding(
786
787
  nhood_adata,
787
788
  "X_milo_graph",
788
789
  color="graph_color",
@@ -798,33 +799,42 @@ class Milo:
798
799
  color_map=color_map,
799
800
  palette=palette,
800
801
  ax=ax,
801
- show=show,
802
- save=save,
802
+ show=False,
803
803
  **kwargs,
804
804
  )
805
805
 
806
+ if show:
807
+ plt.show()
808
+ if return_fig:
809
+ return fig
810
+ return None
811
+
812
+ @_doc_params(common_plot_args=doc_common_plot_args)
806
813
  def plot_nhood(
807
814
  self,
808
815
  mdata: MuData,
809
816
  ix: int,
817
+ *,
810
818
  feature_key: str | None = "rna",
811
819
  basis: str = "X_umap",
812
820
  color_map: Colormap | str | None = None,
813
821
  palette: str | Sequence[str] | None = None,
814
- return_fig: bool | None = None,
815
822
  ax: Axes | None = None,
816
- show: bool | None = None,
817
- save: bool | str | None = None,
823
+ show: bool = True,
824
+ return_fig: bool = False,
818
825
  **kwargs,
819
- ) -> None:
826
+ ) -> Figure | None:
820
827
  """Visualize cells in a neighbourhood.
821
828
 
822
829
  Args:
823
830
  mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
824
831
  ix: index of neighbourhood to visualize
832
+ feature_key: Key in mdata to the cell-level AnnData object.
825
833
  basis: Embedding to use for visualization.
826
- show: Show the plot, do not return axis.
827
- save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
834
+ color_map: Colormap to use for coloring.
835
+ palette: Color palette to use for coloring.
836
+ ax: Axes to plot on.
837
+ {common_plot_args}
828
838
  **kwargs: Additional arguments to `scanpy.pl.embedding`.
829
839
 
830
840
  Examples:
@@ -842,7 +852,7 @@ class Milo:
842
852
  .. image:: /_static/docstring_previews/milo_nhood.png
843
853
  """
844
854
  mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
845
- sc.pl.embedding(
855
+ fig = sc.pl.embedding(
846
856
  mdata[feature_key],
847
857
  basis,
848
858
  color="Nhood",
@@ -852,32 +862,43 @@ class Milo:
852
862
  palette=palette,
853
863
  return_fig=return_fig,
854
864
  ax=ax,
855
- show=show,
856
- save=save,
865
+ show=False,
857
866
  **kwargs,
858
867
  )
859
868
 
869
+ if show:
870
+ plt.show()
871
+ if return_fig:
872
+ return fig
873
+ return None
874
+
875
+ @_doc_params(common_plot_args=doc_common_plot_args)
860
876
  def plot_da_beeswarm(
861
877
  self,
862
878
  mdata: MuData,
879
+ *,
863
880
  feature_key: str | None = "rna",
864
881
  anno_col: str = "nhood_annotation",
865
882
  alpha: float = 0.1,
866
883
  subset_nhoods: list[str] = None,
867
884
  palette: str | Sequence[str] | dict[str, str] | None = None,
868
- return_fig: bool | None = None,
869
- save: bool | str | None = None,
870
- show: bool | None = None,
871
- ) -> Figure | Axes | None:
885
+ show: bool = True,
886
+ return_fig: bool = False,
887
+ ) -> Figure | None:
872
888
  """Plot beeswarm plot of logFC against nhood labels
873
889
 
874
890
  Args:
875
891
  mdata: MuData object
892
+ feature_key: Key in mdata to the cell-level AnnData object.
876
893
  anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
877
894
  alpha: Significance threshold. (default: 0.1)
878
895
  subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
879
896
  palette: Name of Seaborn color palette for violinplots.
880
897
  Defaults to pre-defined category colors for violinplots.
898
+ {common_plot_args}
899
+
900
+ Returns:
901
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
881
902
 
882
903
  Examples:
883
904
  >>> import pertpy as pt
@@ -973,29 +994,23 @@ class Milo:
973
994
  plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
974
995
  plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
975
996
 
976
- if save:
977
- plt.savefig(save, bbox_inches="tight")
978
- return None
979
997
  if show:
980
998
  plt.show()
981
- return None
982
999
  if return_fig:
983
1000
  return plt.gcf()
984
- if (not show and not save) or (show is None and save is None):
985
- return plt.gca()
986
-
987
1001
  return None
988
1002
 
1003
+ @_doc_params(common_plot_args=doc_common_plot_args)
989
1004
  def plot_nhood_counts_by_cond(
990
1005
  self,
991
1006
  mdata: MuData,
992
1007
  test_var: str,
1008
+ *,
993
1009
  subset_nhoods: list[str] = None,
994
1010
  log_counts: bool = False,
995
- return_fig: bool | None = None,
996
- save: bool | str | None = None,
997
- show: bool | None = None,
998
- ) -> Figure | Axes | None:
1011
+ show: bool = True,
1012
+ return_fig: bool = False,
1013
+ ) -> Figure | None:
999
1014
  """Plot boxplot of cell numbers vs condition of interest.
1000
1015
 
1001
1016
  Args:
@@ -1003,6 +1018,10 @@ class Milo:
1003
1018
  test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
1004
1019
  subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
1005
1020
  log_counts: Whether to plot log1p of cell counts.
1021
+ {common_plot_args}
1022
+
1023
+ Returns:
1024
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1006
1025
  """
1007
1026
  try:
1008
1027
  nhood_adata = mdata["milo"].T.copy()
@@ -1031,15 +1050,8 @@ class Milo:
1031
1050
  plt.xticks(rotation=90)
1032
1051
  plt.xlabel(test_var)
1033
1052
 
1034
- if save:
1035
- plt.savefig(save, bbox_inches="tight")
1036
- return None
1037
1053
  if show:
1038
1054
  plt.show()
1039
- return None
1040
1055
  if return_fig:
1041
1056
  return plt.gcf()
1042
- if not (show or save):
1043
- return plt.gca()
1044
-
1045
1057
  return None