pertpy 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@ from collections.abc import Sequence
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
5
- from scanpy import logging
5
+ from lamin_utils import logger
6
6
  from scipy.sparse import issparse
7
7
 
8
8
  from ._base import LinearModelBase
@@ -27,16 +27,11 @@ class EdgeR(LinearModelBase):
27
27
  # pandas2ri.activate()
28
28
  # rpy2.robjects.numpy2ri.activate()
29
29
  try:
30
- import rpy2.robjects.numpy2ri
31
- import rpy2.robjects.pandas2ri
32
30
  from rpy2 import robjects as ro
33
31
  from rpy2.robjects import numpy2ri, pandas2ri
34
- from rpy2.robjects.conversion import localconverter
32
+ from rpy2.robjects.conversion import get_conversion, localconverter
35
33
  from rpy2.robjects.packages import importr
36
34
 
37
- pandas2ri.activate()
38
- rpy2.robjects.numpy2ri.activate()
39
-
40
35
  except ImportError:
41
36
  raise ImportError("edger requires rpy2 to be installed.") from None
42
37
 
@@ -49,25 +44,30 @@ class EdgeR(LinearModelBase):
49
44
  ) from e
50
45
 
51
46
  # Convert dataframe
52
- with localconverter(ro.default_converter + numpy2ri.converter):
47
+ with localconverter(get_conversion() + numpy2ri.converter):
53
48
  expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
54
49
  if issparse(expr):
55
50
  expr = expr.T.toarray()
56
51
  else:
57
52
  expr = expr.T
58
53
 
59
- expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
54
+ with localconverter(get_conversion() + pandas2ri.converter):
55
+ expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
56
+ samples_r = ro.conversion.py2rpy(self.adata.obs)
60
57
 
61
- dge = edger.DGEList(counts=expr_r, samples=self.adata.obs)
58
+ dge = edger.DGEList(counts=expr_r, samples=samples_r)
62
59
 
63
- logging.info("Calculating NormFactors")
60
+ logger.info("Calculating NormFactors")
64
61
  dge = edger.calcNormFactors(dge)
65
62
 
66
- logging.info("Estimating Dispersions")
67
- dge = edger.estimateDisp(dge, design=self.design)
63
+ with localconverter(get_conversion() + numpy2ri.converter):
64
+ design_r = ro.conversion.py2rpy(self.design.values)
65
+
66
+ logger.info("Estimating Dispersions")
67
+ dge = edger.estimateDisp(dge, design=design_r)
68
68
 
69
- logging.info("Fitting linear model")
70
- fit = edger.glmQLFit(dge, design=self.design, **kwargs)
69
+ logger.info("Fitting linear model")
70
+ fit = edger.glmQLFit(dge, design=design_r, **kwargs)
71
71
 
72
72
  ro.globalenv["fit"] = fit
73
73
  self.fit = fit
@@ -88,11 +88,9 @@ class EdgeR(LinearModelBase):
88
88
  # Fix mask for .fit()
89
89
 
90
90
  try:
91
- import rpy2.robjects.numpy2ri
92
- import rpy2.robjects.pandas2ri
93
91
  from rpy2 import robjects as ro
94
92
  from rpy2.robjects import numpy2ri, pandas2ri
95
- from rpy2.robjects.conversion import localconverter
93
+ from rpy2.robjects.conversion import get_conversion, localconverter
96
94
  from rpy2.robjects.packages import importr
97
95
 
98
96
  except ImportError:
@@ -106,7 +104,8 @@ class EdgeR(LinearModelBase):
106
104
  ) from None
107
105
 
108
106
  # Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
109
- contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
107
+ with localconverter(get_conversion() + numpy2ri.converter):
108
+ contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
110
109
  ro.globalenv["contrast_vec"] = contrast_vec_r
111
110
 
112
111
  # Test contrast with R
@@ -117,8 +116,18 @@ class EdgeR(LinearModelBase):
117
116
  """
118
117
  )
119
118
 
120
- # Convert results to pandas
121
- de_res = ro.conversion.rpy2py(ro.globalenv["de_res"])
119
+ # Retrieve the `de_res` object
120
+ de_res = ro.globalenv["de_res"]
121
+
122
+ # If already a Pandas DataFrame, return it directly
123
+ if isinstance(de_res, pd.DataFrame):
124
+ de_res.index.name = "variable"
125
+ return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
126
+
127
+ # Convert to Pandas DataFrame if still an R object
128
+ with localconverter(get_conversion() + pandas2ri.converter):
129
+ de_res = ro.conversion.rpy2py(de_res)
130
+
122
131
  de_res.index.name = "variable"
123
132
  de_res = de_res.reset_index()
124
133
 
@@ -2,6 +2,7 @@ import os
2
2
  import re
3
3
  import warnings
4
4
 
5
+ import numpy as np
5
6
  import pandas as pd
6
7
  from anndata import AnnData
7
8
  from numpy import ndarray
@@ -40,33 +41,25 @@ class PyDESeq2(LinearModelBase):
40
41
  Args:
41
42
  **kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
42
43
  """
43
- inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", os.cpu_count() - 1))
44
- covars = self.design.columns.tolist()
45
- if "Intercept" not in covars:
46
- warnings.warn(
47
- "Warning: Pydeseq is hard-coded to use Intercept, please include intercept into the model", stacklevel=2
48
- )
49
- processed_covars = list({re.sub(r"\[T\.(.*)\]", "", col) for col in covars if col != "Intercept"})
44
+ try:
45
+ usable_cpus = len(os.sched_getaffinity(0))
46
+ except AttributeError:
47
+ usable_cpus = os.cpu_count()
48
+
49
+ inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
50
+
50
51
  dds = DeseqDataSet(
51
- adata=self.adata, design_factors=processed_covars, refit_cooks=True, inference=inference, **kwargs
52
+ adata=self.adata,
53
+ design=self.design, # initialize using design matrix, not formula
54
+ refit_cooks=True,
55
+ inference=inference,
56
+ **kwargs,
52
57
  )
53
- # workaround code to insert design array
54
- des_mtx_cols = dds.obsm["design_matrix"].columns
55
- dds.obsm["design_matrix"] = self.design
56
- if dds.obsm["design_matrix"].shape[1] == len(des_mtx_cols):
57
- dds.obsm["design_matrix"].columns = des_mtx_cols.copy()
58
58
 
59
59
  dds.deseq2()
60
60
  self.dds = dds
61
61
 
62
- # TODO: PyDeseq2 doesn't support arbitrary designs and contrasts yet
63
- # see https://github.com/owkin/PyDESeq2/issues/213
64
-
65
- # Therefore these functions are overridden in a way to make it work with PyDESeq2,
66
- # ingoring the inconsistency of function signatures. Once arbitrary design
67
- # matrices and contrasts are supported by PyDEseq2, we can fully support the
68
- # Linear model interface.
69
- def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame: # type: ignore
62
+ def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
70
63
  """Conduct a specific test and returns a Pandas DataFrame.
71
64
 
72
65
  Args:
@@ -74,6 +67,7 @@ class PyDESeq2(LinearModelBase):
74
67
  alpha: p value threshold used for controlling fdr with independent hypothesis weighting
75
68
  **kwargs: extra arguments to pass to DeseqStats()
76
69
  """
70
+ contrast = np.array(contrast)
77
71
  stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
78
72
  # Calling `.summary()` is required to fill the `results_df` data frame
79
73
  stat_res.summary()
@@ -85,11 +79,3 @@ class PyDESeq2(LinearModelBase):
85
79
  res_df.index.name = "variable"
86
80
  res_df = res_df.reset_index()
87
81
  return res_df
88
-
89
- def cond(self, **kwargs) -> ndarray:
90
- raise NotImplementedError(
91
- "PyDESeq2 currently doesn't support arbitrary contrasts, see https://github.com/owkin/PyDESeq2/issues/213"
92
- )
93
-
94
- def contrast(self, column: str, baseline: str, group_to_compare: str) -> tuple[str, str, str]: # type: ignore
95
- return (column, group_to_compare, baseline)
@@ -59,14 +59,3 @@ class Statsmodels(LinearModelBase):
59
59
  }
60
60
  )
61
61
  return pd.DataFrame(res).sort_values("p_value")
62
-
63
- def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndarray:
64
- """Build a simple contrast for pairwise comparisons.
65
-
66
- This is equivalent to
67
-
68
- ```
69
- model.cond(<column> = baseline) - model.cond(<column> = group_to_compare)
70
- ```
71
- """
72
- return self.cond(**{column: baseline}) - self.cond(**{column: group_to_compare})
@@ -344,9 +344,9 @@ class Distance:
344
344
  else:
345
345
  embedding = adata.obsm[self.obsm_key].copy()
346
346
  for index_x, group_x in enumerate(fct(groups)):
347
- cells_x = embedding[grouping == group_x].copy()
347
+ cells_x = embedding[np.asarray(grouping == group_x)].copy()
348
348
  for group_y in groups[index_x:]: # type: ignore
349
- cells_y = embedding[grouping == group_y].copy()
349
+ cells_y = embedding[np.asarray(grouping == group_y)].copy()
350
350
  if not bootstrap:
351
351
  # By distance axiom, the distance between a group and itself is 0
352
352
  dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
@@ -478,9 +478,9 @@ class Distance:
478
478
  else:
479
479
  embedding = adata.obsm[self.obsm_key].copy()
480
480
  for group_x in fct(groups):
481
- cells_x = embedding[grouping == group_x].copy()
481
+ cells_x = embedding[np.asarray(grouping == group_x)].copy()
482
482
  group_y = selected_group
483
- cells_y = embedding[grouping == group_y].copy()
483
+ cells_y = embedding[np.asarray(grouping == group_y)].copy()
484
484
  if not bootstrap:
485
485
  # By distance axiom, the distance between a group and itself is 0
486
486
  dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
@@ -691,17 +691,18 @@ class MMD(AbstractDistance):
691
691
 
692
692
 
693
693
  class WassersteinDistance(AbstractDistance):
694
- """Wasserstein distance metric (solved with entropy regularized Sinkhorn)."""
695
-
696
694
  def __init__(self) -> None:
697
695
  super().__init__()
698
696
  self.accepts_precomputed = False
699
697
 
700
698
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
699
+ X = np.asarray(X, dtype=np.float64)
700
+ Y = np.asarray(Y, dtype=np.float64)
701
701
  geom = PointCloud(X, Y)
702
702
  return self.solve_ot_problem(geom, **kwargs)
703
703
 
704
704
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
705
+ P = np.asarray(P, dtype=np.float64)
705
706
  geom = Geometry(cost_matrix=P[idx, :][:, ~idx])
706
707
  return self.solve_ot_problem(geom, **kwargs)
707
708
 
@@ -709,7 +710,13 @@ class WassersteinDistance(AbstractDistance):
709
710
  ot_prob = LinearProblem(geom)
710
711
  solver = Sinkhorn()
711
712
  ot = solver(ot_prob, **kwargs)
712
- return ot.reg_ot_cost.item()
713
+ cost = float(ot.reg_ot_cost)
714
+
715
+ # Check for NaN or invalid cost
716
+ if not np.isfinite(cost):
717
+ return 1.0
718
+ else:
719
+ return cost
713
720
 
714
721
 
715
722
  class EuclideanDistance(AbstractDistance):
@@ -981,7 +988,7 @@ class NBLL(AbstractDistance):
981
988
  try:
982
989
  nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
983
990
  return _compute_nll(y, nb_params, epsilon)
984
- except np.linalg.linalg.LinAlgError:
991
+ except np.linalg.LinAlgError:
985
992
  if x.mean() < 10 and y.mean() < 10:
986
993
  return 0.0
987
994
  else:
@@ -3,6 +3,7 @@ from collections.abc import Sequence
3
3
  from typing import Any, Literal
4
4
 
5
5
  import blitzgsea
6
+ import matplotlib.pyplot as plt
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  import scanpy as sc
@@ -14,6 +15,7 @@ from scipy.sparse import issparse
14
15
  from scipy.stats import hypergeom
15
16
  from statsmodels.stats.multitest import multipletests
16
17
 
18
+ from pertpy._doc import _doc_params, doc_common_plot_args
17
19
  from pertpy.metadata import Drug
18
20
 
19
21
 
@@ -290,9 +292,11 @@ class Enrichment:
290
292
 
291
293
  return enrichment
292
294
 
295
+ @_doc_params(common_plot_args=doc_common_plot_args)
293
296
  def plot_dotplot(
294
297
  self,
295
298
  adata: AnnData,
299
+ *,
296
300
  targets: dict[str, dict[str, list[str]]] = None,
297
301
  source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
298
302
  category_name: str = "interaction_type",
@@ -300,10 +304,10 @@ class Enrichment:
300
304
  groupby: str = None,
301
305
  key: str = "pertpy_enrichment",
302
306
  ax: Axes | None = None,
303
- save: bool | str | None = None,
304
- show: bool | None = None,
307
+ show: bool = True,
308
+ return_fig: bool = False,
305
309
  **kwargs,
306
- ) -> DotPlot | dict | None:
310
+ ) -> DotPlot | None:
307
311
  """Plots a dotplot by groupby and categories.
308
312
 
309
313
  Wraps scanpy's dotplot but formats it nicely by categories.
@@ -319,11 +323,11 @@ class Enrichment:
319
323
  category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
320
324
  groupby: dotplot groupby such as clusters or cell types.
321
325
  key: Prefix key of enrichment results in `uns`.
326
+ {common_plot_args}
322
327
  kwargs: Passed to scanpy dotplot.
323
328
 
324
329
  Returns:
325
- If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object,
326
- else if `show` is false, return axes dict.
330
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
327
331
 
328
332
  Examples:
329
333
  >>> import pertpy as pt
@@ -403,21 +407,27 @@ class Enrichment:
403
407
  "var_group_labels": var_group_labels,
404
408
  }
405
409
 
406
- return sc.pl.dotplot(
410
+ fig = sc.pl.dotplot(
407
411
  enrichment_score_adata,
408
412
  groupby=groupby,
409
413
  swap_axes=True,
410
414
  ax=ax,
411
- save=save,
412
- show=show,
415
+ show=False,
413
416
  **plot_args,
414
417
  **kwargs,
415
418
  )
416
419
 
420
+ if show:
421
+ plt.show()
422
+ if return_fig:
423
+ return fig
424
+ return None
425
+
417
426
  def plot_gsea(
418
427
  self,
419
428
  adata: AnnData,
420
429
  enrichment: dict[str, pd.DataFrame],
430
+ *,
421
431
  n: int = 10,
422
432
  key: str = "pertpy_enrichment_gsea",
423
433
  interactive_plot: bool = False,
pertpy/tools/_milo.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import random
5
4
  import re
6
5
  from typing import TYPE_CHECKING, Literal
@@ -14,6 +13,8 @@ from anndata import AnnData
14
13
  from lamin_utils import logger
15
14
  from mudata import MuData
16
15
 
16
+ from pertpy._doc import _doc_params, doc_common_plot_args
17
+
17
18
  if TYPE_CHECKING:
18
19
  from collections.abc import Sequence
19
20
 
@@ -125,7 +126,7 @@ class Milo:
125
126
  try:
126
127
  use_rep = adata.uns["neighbors"]["params"]["use_rep"]
127
128
  except KeyError:
128
- logging.warning("Using X_pca as default embedding")
129
+ logger.warning("Using X_pca as default embedding")
129
130
  use_rep = "X_pca"
130
131
  try:
131
132
  knn_graph = adata.obsp["connectivities"].copy()
@@ -136,7 +137,7 @@ class Milo:
136
137
  try:
137
138
  use_rep = adata.uns[neighbors_key]["params"]["use_rep"]
138
139
  except KeyError:
139
- logging.warning("Using X_pca as default embedding")
140
+ logger.warning("Using X_pca as default embedding")
140
141
  use_rep = "X_pca"
141
142
  knn_graph = adata.obsp[neighbors_key + "_connectivities"].copy()
142
143
 
@@ -182,7 +183,7 @@ class Milo:
182
183
  knn_dists = adata.obsp[neighbors_key + "_distances"]
183
184
 
184
185
  nhood_ixs = adata.obs["nhood_ixs_refined"] == 1
185
- dist_mat = knn_dists[nhood_ixs, :]
186
+ dist_mat = knn_dists[np.asarray(nhood_ixs), :]
186
187
  k_distances = dist_mat.max(1).toarray().ravel()
187
188
  adata.obs["nhood_kth_distance"] = 0
188
189
  adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
@@ -703,8 +704,8 @@ class Milo:
703
704
  pvalues = sample_adata.var["PValue"]
704
705
  keep_nhoods = ~pvalues.isna() # Filtering in case of test on subset of nhoods
705
706
  o = pvalues[keep_nhoods].argsort()
706
- pvalues = pvalues[keep_nhoods][o]
707
- w = w[keep_nhoods][o]
707
+ pvalues = pvalues.loc[keep_nhoods].iloc[o]
708
+ w = w.loc[keep_nhoods].iloc[o]
708
709
 
709
710
  adjp = np.zeros(shape=len(o))
710
711
  adjp[o] = (sum(w) * pvalues / np.cumsum(w))[::-1].cummin()[::-1]
@@ -713,9 +714,11 @@ class Milo:
713
714
  sample_adata.var["SpatialFDR"] = np.nan
714
715
  sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
715
716
 
717
+ @_doc_params(common_plot_args=doc_common_plot_args)
716
718
  def plot_nhood_graph(
717
719
  self,
718
720
  mdata: MuData,
721
+ *,
719
722
  alpha: float = 0.1,
720
723
  min_logFC: float = 0,
721
724
  min_size: int = 10,
@@ -724,10 +727,10 @@ class Milo:
724
727
  color_map: Colormap | str | None = None,
725
728
  palette: str | Sequence[str] | None = None,
726
729
  ax: Axes | None = None,
727
- show: bool | None = None,
728
- save: bool | str | None = None,
730
+ show: bool = True,
731
+ return_fig: bool = False,
729
732
  **kwargs,
730
- ) -> None:
733
+ ) -> Figure | None:
731
734
  """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding)
732
735
 
733
736
  Args:
@@ -737,9 +740,7 @@ class Milo:
737
740
  min_size: Minimum size of nodes in visualization. (default: 10)
738
741
  plot_edges: If edges for neighbourhood overlaps whould be plotted.
739
742
  title: Plot title.
740
- show: Show the plot, do not return axis.
741
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
742
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
743
+ {common_plot_args}
743
744
  **kwargs: Additional arguments to `scanpy.pl.embedding`.
744
745
 
745
746
  Examples:
@@ -782,7 +783,7 @@ class Milo:
782
783
  vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
783
784
  vmin = -vmax
784
785
 
785
- sc.pl.embedding(
786
+ fig = sc.pl.embedding(
786
787
  nhood_adata,
787
788
  "X_milo_graph",
788
789
  color="graph_color",
@@ -798,33 +799,42 @@ class Milo:
798
799
  color_map=color_map,
799
800
  palette=palette,
800
801
  ax=ax,
801
- show=show,
802
- save=save,
802
+ show=False,
803
803
  **kwargs,
804
804
  )
805
805
 
806
+ if show:
807
+ plt.show()
808
+ if return_fig:
809
+ return fig
810
+ return None
811
+
812
+ @_doc_params(common_plot_args=doc_common_plot_args)
806
813
  def plot_nhood(
807
814
  self,
808
815
  mdata: MuData,
809
816
  ix: int,
817
+ *,
810
818
  feature_key: str | None = "rna",
811
819
  basis: str = "X_umap",
812
820
  color_map: Colormap | str | None = None,
813
821
  palette: str | Sequence[str] | None = None,
814
- return_fig: bool | None = None,
815
822
  ax: Axes | None = None,
816
- show: bool | None = None,
817
- save: bool | str | None = None,
823
+ show: bool = True,
824
+ return_fig: bool = False,
818
825
  **kwargs,
819
- ) -> None:
826
+ ) -> Figure | None:
820
827
  """Visualize cells in a neighbourhood.
821
828
 
822
829
  Args:
823
830
  mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
824
831
  ix: index of neighbourhood to visualize
832
+ feature_key: Key in mdata to the cell-level AnnData object.
825
833
  basis: Embedding to use for visualization.
826
- show: Show the plot, do not return axis.
827
- save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
834
+ color_map: Colormap to use for coloring.
835
+ palette: Color palette to use for coloring.
836
+ ax: Axes to plot on.
837
+ {common_plot_args}
828
838
  **kwargs: Additional arguments to `scanpy.pl.embedding`.
829
839
 
830
840
  Examples:
@@ -842,7 +852,7 @@ class Milo:
842
852
  .. image:: /_static/docstring_previews/milo_nhood.png
843
853
  """
844
854
  mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
845
- sc.pl.embedding(
855
+ fig = sc.pl.embedding(
846
856
  mdata[feature_key],
847
857
  basis,
848
858
  color="Nhood",
@@ -852,32 +862,43 @@ class Milo:
852
862
  palette=palette,
853
863
  return_fig=return_fig,
854
864
  ax=ax,
855
- show=show,
856
- save=save,
865
+ show=False,
857
866
  **kwargs,
858
867
  )
859
868
 
869
+ if show:
870
+ plt.show()
871
+ if return_fig:
872
+ return fig
873
+ return None
874
+
875
+ @_doc_params(common_plot_args=doc_common_plot_args)
860
876
  def plot_da_beeswarm(
861
877
  self,
862
878
  mdata: MuData,
879
+ *,
863
880
  feature_key: str | None = "rna",
864
881
  anno_col: str = "nhood_annotation",
865
882
  alpha: float = 0.1,
866
883
  subset_nhoods: list[str] = None,
867
884
  palette: str | Sequence[str] | dict[str, str] | None = None,
868
- return_fig: bool | None = None,
869
- save: bool | str | None = None,
870
- show: bool | None = None,
871
- ) -> Figure | Axes | None:
885
+ show: bool = True,
886
+ return_fig: bool = False,
887
+ ) -> Figure | None:
872
888
  """Plot beeswarm plot of logFC against nhood labels
873
889
 
874
890
  Args:
875
891
  mdata: MuData object
892
+ feature_key: Key in mdata to the cell-level AnnData object.
876
893
  anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
877
894
  alpha: Significance threshold. (default: 0.1)
878
895
  subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
879
896
  palette: Name of Seaborn color palette for violinplots.
880
897
  Defaults to pre-defined category colors for violinplots.
898
+ {common_plot_args}
899
+
900
+ Returns:
901
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
881
902
 
882
903
  Examples:
883
904
  >>> import pertpy as pt
@@ -973,29 +994,23 @@ class Milo:
973
994
  plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
974
995
  plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
975
996
 
976
- if save:
977
- plt.savefig(save, bbox_inches="tight")
978
- return None
979
997
  if show:
980
998
  plt.show()
981
- return None
982
999
  if return_fig:
983
1000
  return plt.gcf()
984
- if (not show and not save) or (show is None and save is None):
985
- return plt.gca()
986
-
987
1001
  return None
988
1002
 
1003
+ @_doc_params(common_plot_args=doc_common_plot_args)
989
1004
  def plot_nhood_counts_by_cond(
990
1005
  self,
991
1006
  mdata: MuData,
992
1007
  test_var: str,
1008
+ *,
993
1009
  subset_nhoods: list[str] = None,
994
1010
  log_counts: bool = False,
995
- return_fig: bool | None = None,
996
- save: bool | str | None = None,
997
- show: bool | None = None,
998
- ) -> Figure | Axes | None:
1011
+ show: bool = True,
1012
+ return_fig: bool = False,
1013
+ ) -> Figure | None:
999
1014
  """Plot boxplot of cell numbers vs condition of interest.
1000
1015
 
1001
1016
  Args:
@@ -1003,6 +1018,10 @@ class Milo:
1003
1018
  test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
1004
1019
  subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
1005
1020
  log_counts: Whether to plot log1p of cell counts.
1021
+ {common_plot_args}
1022
+
1023
+ Returns:
1024
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1006
1025
  """
1007
1026
  try:
1008
1027
  nhood_adata = mdata["milo"].T.copy()
@@ -1031,15 +1050,8 @@ class Milo:
1031
1050
  plt.xticks(rotation=90)
1032
1051
  plt.xlabel(test_var)
1033
1052
 
1034
- if save:
1035
- plt.savefig(save, bbox_inches="tight")
1036
- return None
1037
1053
  if show:
1038
1054
  plt.show()
1039
- return None
1040
1055
  if return_fig:
1041
1056
  return plt.gcf()
1042
- if not (show or save):
1043
- return plt.gca()
1044
-
1045
1057
  return None