pertpy 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from pathlib import Path
5
- from typing import TYPE_CHECKING, Literal, Optional, Union
5
+ from typing import TYPE_CHECKING, Literal
6
6
 
7
7
  import arviz as az
8
8
  import jax.numpy as jnp
@@ -26,6 +26,8 @@ from rich.console import Console
26
26
  from rich.table import Table
27
27
  from scipy.cluster import hierarchy as sp_hierarchy
28
28
 
29
+ from pertpy._doc import _doc_params, doc_common_plot_args
30
+
29
31
  if TYPE_CHECKING:
30
32
  from collections.abc import Sequence
31
33
 
@@ -307,7 +309,7 @@ class CompositionalModel2(ABC):
307
309
  if copy:
308
310
  sample_adata = sample_adata.copy()
309
311
 
310
- rng_key_array = random.key(rng_key)
312
+ rng_key_array = random.key_data(random.key(rng_key))
311
313
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
312
314
 
313
315
  # Set up NUTS kernel
@@ -1023,7 +1025,7 @@ class CompositionalModel2(ABC):
1023
1025
  >>> key_added="lineage", add_level_name=True
1024
1026
  >>> )
1025
1027
  >>> mdata = tasccoda.prepare(
1026
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1028
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi" : 0}
1027
1029
  >>> )
1028
1030
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1029
1031
  >>> node_effects = tasccoda.get_node_df(mdata)
@@ -1185,22 +1187,21 @@ class CompositionalModel2(ABC):
1185
1187
 
1186
1188
  return ax
1187
1189
 
1190
+ @_doc_params(common_plot_args=doc_common_plot_args)
1188
1191
  def plot_stacked_barplot( # pragma: no cover
1189
1192
  self,
1190
1193
  data: AnnData | MuData,
1191
1194
  feature_name: str,
1195
+ *,
1192
1196
  modality_key: str = "coda",
1193
1197
  palette: ListedColormap | None = cm.tab20,
1194
1198
  show_legend: bool | None = True,
1195
1199
  level_order: list[str] = None,
1196
1200
  figsize: tuple[float, float] | None = None,
1197
1201
  dpi: int | None = 100,
1198
- return_fig: bool | None = None,
1199
- ax: plt.Axes | None = None,
1200
- show: bool | None = None,
1201
- save: str | bool | None = None,
1202
- **kwargs,
1203
- ) -> plt.Axes | plt.Figure | None:
1202
+ show: bool = True,
1203
+ return_fig: bool = False,
1204
+ ) -> Figure | None:
1204
1205
  """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
1205
1206
 
1206
1207
  Args:
@@ -1212,9 +1213,10 @@ class CompositionalModel2(ABC):
1212
1213
  palette: The matplotlib color map for the barplot.
1213
1214
  show_legend: If True, adds a legend.
1214
1215
  level_order: Custom ordering of bars on the x-axis.
1216
+ {common_plot_args}
1215
1217
 
1216
1218
  Returns:
1217
- A :class:`~matplotlib.axes.Axes` object
1219
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1218
1220
 
1219
1221
  Examples:
1220
1222
  >>> import pertpy as pt
@@ -1239,7 +1241,7 @@ class CompositionalModel2(ABC):
1239
1241
  if level_order:
1240
1242
  assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
1241
1243
  data = data[level_order]
1242
- ax = self._stackbar(
1244
+ self._stackbar(
1243
1245
  data.X,
1244
1246
  type_names=data.var.index,
1245
1247
  title="samples",
@@ -1265,7 +1267,7 @@ class CompositionalModel2(ABC):
1265
1267
  l_indices = np.where(data.obs[feature_name] == levels[level])
1266
1268
  feature_totals[level] = np.sum(data.X[l_indices], axis=0)
1267
1269
 
1268
- ax = self._stackbar(
1270
+ self._stackbar(
1269
1271
  feature_totals,
1270
1272
  type_names=ct_names,
1271
1273
  title=feature_name,
@@ -1276,19 +1278,17 @@ class CompositionalModel2(ABC):
1276
1278
  show_legend=show_legend,
1277
1279
  )
1278
1280
 
1279
- if save:
1280
- plt.savefig(save, bbox_inches="tight")
1281
1281
  if show:
1282
1282
  plt.show()
1283
1283
  if return_fig:
1284
1284
  return plt.gcf()
1285
- if not (show or save):
1286
- return ax
1287
1285
  return None
1288
1286
 
1287
+ @_doc_params(common_plot_args=doc_common_plot_args)
1289
1288
  def plot_effects_barplot( # pragma: no cover
1290
1289
  self,
1291
1290
  data: AnnData | MuData,
1291
+ *,
1292
1292
  modality_key: str = "coda",
1293
1293
  covariates: str | list | None = None,
1294
1294
  parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
@@ -1300,11 +1300,9 @@ class CompositionalModel2(ABC):
1300
1300
  args_barplot: dict | None = None,
1301
1301
  figsize: tuple[float, float] | None = None,
1302
1302
  dpi: int | None = 100,
1303
- return_fig: bool | None = None,
1304
- ax: plt.Axes | None = None,
1305
- show: bool | None = None,
1306
- save: str | bool | None = None,
1307
- ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1303
+ show: bool = True,
1304
+ return_fig: bool = False,
1305
+ ) -> Figure | None:
1308
1306
  """Barplot visualization for effects.
1309
1307
 
1310
1308
  The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
@@ -1323,10 +1321,10 @@ class CompositionalModel2(ABC):
1323
1321
  palette: The seaborn color map for the barplot.
1324
1322
  level_order: Custom ordering of bars on the x-axis.
1325
1323
  args_barplot: Arguments passed to sns.barplot.
1324
+ {common_plot_args}
1326
1325
 
1327
1326
  Returns:
1328
- Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1329
- or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1327
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1330
1328
 
1331
1329
  Examples:
1332
1330
  >>> import pertpy as pt
@@ -1380,7 +1378,6 @@ class CompositionalModel2(ABC):
1380
1378
  if len(covariate_names_zero) != 0:
1381
1379
  if plot_facets:
1382
1380
  if plot_zero_covariate and not plot_zero_cell_type:
1383
- plot_df = plot_df[plot_df["value"] != 0]
1384
1381
  for covariate_name_zero in covariate_names_zero:
1385
1382
  new_row = {
1386
1383
  "Covariate": covariate_name_zero,
@@ -1437,16 +1434,6 @@ class CompositionalModel2(ABC):
1437
1434
  if ax.get_xticklabels()[0]._text == "zero":
1438
1435
  ax.set_xticks([])
1439
1436
 
1440
- if save:
1441
- plt.savefig(save, bbox_inches="tight")
1442
- if show:
1443
- plt.show()
1444
- if return_fig:
1445
- return plt.gcf()
1446
- if not (show or save):
1447
- return g
1448
- return None
1449
-
1450
1437
  # If not plot as facets, call barplot to plot cell types on the x-axis.
1451
1438
  else:
1452
1439
  _, ax = plt.subplots(figsize=figsize, dpi=dpi)
@@ -1478,20 +1465,18 @@ class CompositionalModel2(ABC):
1478
1465
  cell_types = pd.unique(plot_df["Cell Type"])
1479
1466
  ax.set_xticklabels(cell_types, rotation=90)
1480
1467
 
1481
- if save:
1482
- plt.savefig(save, bbox_inches="tight")
1483
- if show:
1484
- plt.show()
1485
- if return_fig:
1486
- return plt.gcf()
1487
- if not (show or save):
1488
- return ax
1489
- return None
1468
+ if show:
1469
+ plt.show()
1470
+ if return_fig:
1471
+ return plt.gcf()
1472
+ return None
1490
1473
 
1474
+ @_doc_params(common_plot_args=doc_common_plot_args)
1491
1475
  def plot_boxplots( # pragma: no cover
1492
1476
  self,
1493
1477
  data: AnnData | MuData,
1494
1478
  feature_name: str,
1479
+ *,
1495
1480
  modality_key: str = "coda",
1496
1481
  y_scale: Literal["relative", "log", "log10", "count"] = "relative",
1497
1482
  plot_facets: bool = False,
@@ -1504,11 +1489,9 @@ class CompositionalModel2(ABC):
1504
1489
  level_order: list[str] = None,
1505
1490
  figsize: tuple[float, float] | None = None,
1506
1491
  dpi: int | None = 100,
1507
- return_fig: bool | None = None,
1508
- ax: plt.Axes | None = None,
1509
- show: bool | None = None,
1510
- save: str | bool | None = None,
1511
- ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1492
+ show: bool = True,
1493
+ return_fig: bool = False,
1494
+ ) -> Figure | None:
1512
1495
  """Grouped boxplot visualization.
1513
1496
 
1514
1497
  The cell counts for each cell type are shown as a group of boxplots
@@ -1530,10 +1513,10 @@ class CompositionalModel2(ABC):
1530
1513
  palette: The seaborn color map for the barplot.
1531
1514
  show_legend: If True, adds a legend.
1532
1515
  level_order: Custom ordering of bars on the x-axis.
1516
+ {common_plot_args}
1533
1517
 
1534
1518
  Returns:
1535
- Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1536
- or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1519
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1537
1520
 
1538
1521
  Examples:
1539
1522
  >>> import pertpy as pt
@@ -1651,16 +1634,6 @@ class CompositionalModel2(ABC):
1651
1634
  **args_swarmplot,
1652
1635
  ).set_titles("{col_name}")
1653
1636
 
1654
- if save:
1655
- plt.savefig(save, bbox_inches="tight")
1656
- if show:
1657
- plt.show()
1658
- if return_fig:
1659
- return plt.gcf()
1660
- if not (show or save):
1661
- return g
1662
- return None
1663
-
1664
1637
  # If not plot as facets, call boxplot to plot cell types on the x-axis.
1665
1638
  else:
1666
1639
  if level_order:
@@ -1724,19 +1697,17 @@ class CompositionalModel2(ABC):
1724
1697
  title=feature_name,
1725
1698
  )
1726
1699
 
1727
- if save:
1728
- plt.savefig(save, bbox_inches="tight")
1729
- if show:
1730
- plt.show()
1731
- if return_fig:
1732
- return plt.gcf()
1733
- if not (show or save):
1734
- return ax
1735
- return None
1700
+ if show:
1701
+ plt.show()
1702
+ if return_fig:
1703
+ return plt.gcf()
1704
+ return None
1736
1705
 
1706
+ @_doc_params(common_plot_args=doc_common_plot_args)
1737
1707
  def plot_rel_abundance_dispersion_plot( # pragma: no cover
1738
1708
  self,
1739
1709
  data: AnnData | MuData,
1710
+ *,
1740
1711
  modality_key: str = "coda",
1741
1712
  abundant_threshold: float | None = 0.9,
1742
1713
  default_color: str | None = "Grey",
@@ -1744,11 +1715,10 @@ class CompositionalModel2(ABC):
1744
1715
  label_cell_types: bool = True,
1745
1716
  figsize: tuple[float, float] | None = None,
1746
1717
  dpi: int | None = 100,
1747
- return_fig: bool | None = None,
1748
1718
  ax: plt.Axes | None = None,
1749
- show: bool | None = None,
1750
- save: str | bool | None = None,
1751
- ) -> plt.Axes | plt.Figure | None:
1719
+ show: bool = True,
1720
+ return_fig: bool = False,
1721
+ ) -> Figure | None:
1752
1722
  """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
1753
1723
 
1754
1724
  If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
@@ -1763,9 +1733,10 @@ class CompositionalModel2(ABC):
1763
1733
  figsize: Figure size.
1764
1734
  dpi: Dpi setting.
1765
1735
  ax: A matplotlib axes object. Only works if plotting a single component.
1736
+ {common_plot_args}
1766
1737
 
1767
1738
  Returns:
1768
- A :class:`~matplotlib.axes.Axes` object
1739
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1769
1740
 
1770
1741
  Examples:
1771
1742
  >>> import pertpy as pt
@@ -1849,19 +1820,17 @@ class CompositionalModel2(ABC):
1849
1820
 
1850
1821
  ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
1851
1822
 
1852
- if save:
1853
- plt.savefig(save, bbox_inches="tight")
1854
1823
  if show:
1855
1824
  plt.show()
1856
1825
  if return_fig:
1857
1826
  return plt.gcf()
1858
- if not (show or save):
1859
- return ax
1860
1827
  return None
1861
1828
 
1829
+ @_doc_params(common_plot_args=doc_common_plot_args)
1862
1830
  def plot_draw_tree( # pragma: no cover
1863
1831
  self,
1864
1832
  data: AnnData | MuData,
1833
+ *,
1865
1834
  modality_key: str = "coda",
1866
1835
  tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1867
1836
  tight_text: bool | None = False,
@@ -1869,8 +1838,9 @@ class CompositionalModel2(ABC):
1869
1838
  units: Literal["px", "mm", "in"] | None = "px",
1870
1839
  figsize: tuple[float, float] | None = (None, None),
1871
1840
  dpi: int | None = 100,
1872
- show: bool | None = True,
1873
- save: str | bool | None = None,
1841
+ save: str | bool = False,
1842
+ show: bool = True,
1843
+ return_fig: bool = False,
1874
1844
  ) -> Tree | None:
1875
1845
  """Plot a tree using input ete3 tree object.
1876
1846
 
@@ -1881,12 +1851,11 @@ class CompositionalModel2(ABC):
1881
1851
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
1882
1852
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1883
1853
  show_scale: Include the scale legend in the tree image or not.
1884
- show: If True, plot the tree inline. If false, return tree and tree_style objects.
1885
- file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
1886
- Output image can be saved whether show is True or not.
1887
1854
  units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
1888
1855
  figsize: Figure size.
1889
1856
  dpi: Dots per inches.
1857
+ save: Save the tree plot to a file. You can specify the file name here.
1858
+ {common_plot_args}
1890
1859
 
1891
1860
  Returns:
1892
1861
  Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
@@ -1901,7 +1870,7 @@ class CompositionalModel2(ABC):
1901
1870
  >>> key_added="lineage", add_level_name=True
1902
1871
  >>> )
1903
1872
  >>> mdata = tasccoda.prepare(
1904
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1873
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args=dict(phi=0)
1905
1874
  >>> )
1906
1875
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1907
1876
  >>> tasccoda.plot_draw_tree(mdata, tree="lineage")
@@ -1936,13 +1905,16 @@ class CompositionalModel2(ABC):
1936
1905
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1937
1906
  if show:
1938
1907
  return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1939
- else:
1908
+ if return_fig:
1940
1909
  return tree, tree_style
1910
+ return None
1941
1911
 
1912
+ @_doc_params(common_plot_args=doc_common_plot_args)
1942
1913
  def plot_draw_effects( # pragma: no cover
1943
1914
  self,
1944
1915
  data: AnnData | MuData,
1945
1916
  covariate: str,
1917
+ *,
1946
1918
  modality_key: str = "coda",
1947
1919
  tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1948
1920
  show_legend: bool | None = None,
@@ -1952,8 +1924,9 @@ class CompositionalModel2(ABC):
1952
1924
  units: Literal["px", "mm", "in"] | None = "px",
1953
1925
  figsize: tuple[float, float] | None = (None, None),
1954
1926
  dpi: int | None = 100,
1955
- show: bool | None = True,
1956
- save: str | None = None,
1927
+ save: str | bool = False,
1928
+ show: bool = True,
1929
+ return_fig: bool = False,
1957
1930
  ) -> Tree | None:
1958
1931
  """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
1959
1932
 
@@ -1968,15 +1941,15 @@ class CompositionalModel2(ABC):
1968
1941
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
1969
1942
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1970
1943
  show_scale: Include the scale legend in the tree image or not.
1971
- show: If True, plot the tree inline. If false, return tree and tree_style objects.
1972
- file_name: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
1973
1944
  units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
1974
1945
  figsize: Figure size.
1975
1946
  dpi: Dots per inches.
1947
+ save: Save the tree plot to a file. You can specify the file name here.
1948
+ {common_plot_args}
1976
1949
 
1977
1950
  Returns:
1978
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
1979
- or plot the tree inline (`show = False`)
1951
+ Returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`return_fig = False`)
1952
+ or plot the tree inline (`show = True`)
1980
1953
 
1981
1954
  Examples:
1982
1955
  >>> import pertpy as pt
@@ -1988,7 +1961,7 @@ class CompositionalModel2(ABC):
1988
1961
  >>> key_added="lineage", add_level_name=True
1989
1962
  >>> )
1990
1963
  >>> mdata = tasccoda.prepare(
1991
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1964
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args=dict(phi=0)
1992
1965
  >>> )
1993
1966
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1994
1967
  >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
@@ -2117,52 +2090,55 @@ class CompositionalModel2(ABC):
2117
2090
  plt.xlim(-leaf_eff_max, leaf_eff_max)
2118
2091
  plt.subplots_adjust(wspace=0)
2119
2092
 
2120
- if save is not None:
2093
+ if save:
2121
2094
  plt.savefig(save)
2122
2095
 
2123
- if save is not None and not show_leaf_effects:
2096
+ if save and not show_leaf_effects:
2124
2097
  tree2.render(save, tree_style=tree_style, units=units)
2125
2098
  if show:
2126
2099
  if not show_leaf_effects:
2127
2100
  return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2128
- else:
2101
+ if return_fig:
2129
2102
  if not show_leaf_effects:
2130
2103
  return tree2, tree_style
2131
2104
  return None
2132
2105
 
2106
+ @_doc_params(common_plot_args=doc_common_plot_args)
2133
2107
  def plot_effects_umap( # pragma: no cover
2134
2108
  self,
2135
2109
  mdata: MuData,
2136
2110
  effect_name: str | list | None,
2137
2111
  cluster_key: str,
2112
+ *,
2138
2113
  modality_key_1: str = "rna",
2139
2114
  modality_key_2: str = "coda",
2140
2115
  color_map: Colormap | str | None = None,
2141
2116
  palette: str | Sequence[str] | None = None,
2142
- return_fig: bool | None = None,
2143
2117
  ax: Axes = None,
2144
- show: bool = None,
2145
- save: str | bool | None = None,
2118
+ show: bool = True,
2119
+ return_fig: bool = False,
2146
2120
  **kwargs,
2147
- ) -> plt.Axes | plt.Figure | None:
2121
+ ) -> Figure | None:
2148
2122
  """Plot a UMAP visualization colored by effect strength.
2149
2123
 
2150
2124
  Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
2151
2125
  (default is data['rna']) depending on the cluster they were assigned to.
2152
2126
 
2153
2127
  Args:
2154
- mudata: MuData object.
2128
+ mdata: MuData object.
2155
2129
  effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
2156
2130
  cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
2157
2131
  To assign cell types' effects to original cells.
2158
2132
  modality_key_1: Key to the cell-level AnnData in the MuData object.
2159
2133
  modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
2160
- show: Whether to display the figure or return axis.
2134
+ color_map: The color map to use for plotting.
2135
+ palette: The color palette to use for plotting.
2161
2136
  ax: A matplotlib axes object. Only works if plotting a single component.
2137
+ {common_plot_args}
2162
2138
  **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
2163
2139
 
2164
2140
  Returns:
2165
- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
2141
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
2166
2142
 
2167
2143
  Examples:
2168
2144
  >>> import pertpy as pt
@@ -2182,7 +2158,7 @@ class CompositionalModel2(ABC):
2182
2158
  >>> modality_key="coda",
2183
2159
  >>> reference_cell_type="18",
2184
2160
  >>> formula="condition",
2185
- >>> pen_args={"phi": 0, "lambda_1": 3.5},
2161
+ >>> pen_args=dict(phi=0, lambda_1=3.5),
2186
2162
  >>> tree_key="tree"
2187
2163
  >>> )
2188
2164
  >>> tasccoda_model.run_nuts(
@@ -2220,7 +2196,7 @@ class CompositionalModel2(ABC):
2220
2196
  else:
2221
2197
  vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))
2222
2198
 
2223
- return sc.pl.umap(
2199
+ fig = sc.pl.umap(
2224
2200
  data_rna,
2225
2201
  color=effect_name,
2226
2202
  vmax=vmax,
@@ -2229,11 +2205,16 @@ class CompositionalModel2(ABC):
2229
2205
  color_map=color_map,
2230
2206
  return_fig=return_fig,
2231
2207
  ax=ax,
2232
- show=show,
2233
- save=save,
2208
+ show=False,
2234
2209
  **kwargs,
2235
2210
  )
2236
2211
 
2212
+ if show:
2213
+ plt.show()
2214
+ if return_fig:
2215
+ return fig
2216
+ return None
2217
+
2237
2218
 
2238
2219
  def get_a(
2239
2220
  tree: tt.core.ToyTree,
pertpy/tools/_dialogue.py CHANGED
@@ -25,6 +25,8 @@ from sklearn.linear_model import LinearRegression
25
25
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
26
26
  from statsmodels.sandbox.stats.multicomp import multipletests
27
27
 
28
+ from pertpy._doc import _doc_params, doc_common_plot_args
29
+
28
30
  if TYPE_CHECKING:
29
31
  from matplotlib.axes import Axes
30
32
  from matplotlib.figure import Figure
@@ -1059,18 +1061,18 @@ class Dialogue:
1059
1061
 
1060
1062
  return rank_dfs
1061
1063
 
1064
+ @_doc_params(common_plot_args=doc_common_plot_args)
1062
1065
  def plot_split_violins(
1063
1066
  self,
1064
1067
  adata: AnnData,
1065
1068
  split_key: str,
1066
1069
  celltype_key: str,
1070
+ *,
1067
1071
  split_which: tuple[str, str] = None,
1068
1072
  mcp: str = "mcp_0",
1069
- return_fig: bool | None = None,
1070
- ax: Axes | None = None,
1071
- save: bool | str | None = None,
1072
- show: bool | None = None,
1073
- ) -> Axes | Figure | None:
1073
+ show: bool = True,
1074
+ return_fig: bool = False,
1075
+ ) -> Figure | None:
1074
1076
  """Plots split violin plots for a given MCP and split variable.
1075
1077
 
1076
1078
  Any cells with a value for split_key not in split_which are removed from the plot.
@@ -1081,9 +1083,10 @@ class Dialogue:
1081
1083
  celltype_key: Key for cell type annotations.
1082
1084
  split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
1083
1085
  mcp: Key for MCP data.
1086
+ {common_plot_args}
1084
1087
 
1085
1088
  Returns:
1086
- A :class:`~matplotlib.axes.Axes` object
1089
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1087
1090
 
1088
1091
  Examples:
1089
1092
  >>> import pertpy as pt
@@ -1105,30 +1108,26 @@ class Dialogue:
1105
1108
  df[split_key] = df[split_key].cat.remove_unused_categories()
1106
1109
 
1107
1110
  ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1108
-
1109
1111
  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1110
1112
 
1111
- if save:
1112
- plt.savefig(save, bbox_inches="tight")
1113
1113
  if show:
1114
1114
  plt.show()
1115
1115
  if return_fig:
1116
1116
  return plt.gcf()
1117
- if not (show or save):
1118
- return ax
1119
1117
  return None
1120
1118
 
1119
+ @_doc_params(common_plot_args=doc_common_plot_args)
1121
1120
  def plot_pairplot(
1122
1121
  self,
1123
1122
  adata: AnnData,
1124
1123
  celltype_key: str,
1125
1124
  color: str,
1126
1125
  sample_id: str,
1126
+ *,
1127
1127
  mcp: str = "mcp_0",
1128
- return_fig: bool | None = None,
1129
- show: bool | None = None,
1130
- save: bool | str | None = None,
1131
- ) -> PairGrid | Figure | None:
1128
+ show: bool = True,
1129
+ return_fig: bool = False,
1130
+ ) -> Figure | None:
1132
1131
  """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
1133
1132
 
1134
1133
  Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
@@ -1140,9 +1139,10 @@ class Dialogue:
1140
1139
  color: Key in `adata.obs` for color annotations. This parameter is used as the hue
1141
1140
  sample_id: Key in `adata.obs` for the sample annotations.
1142
1141
  mcp: Key in `adata.obs` for MCP feature values.
1142
+ {common_plot_args}
1143
1143
 
1144
1144
  Returns:
1145
- Seaborn Pairgrid object.
1145
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1146
1146
 
1147
1147
  Examples:
1148
1148
  >>> import pertpy as pt
@@ -1165,14 +1165,10 @@ class Dialogue:
1165
1165
  aggstats = aggstats.loc[list(mcp_pivot.index), :]
1166
1166
  aggstats[color] = aggstats["top"]
1167
1167
  mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1168
- ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
1168
+ sns.pairplot(mcp_pivot, hue=color, corner=True)
1169
1169
 
1170
- if save:
1171
- plt.savefig(save, bbox_inches="tight")
1172
1170
  if show:
1173
1171
  plt.show()
1174
1172
  if return_fig:
1175
1173
  return plt.gcf()
1176
- if not (show or save):
1177
- return ax
1178
1174
  return None
@@ -1,4 +1,4 @@
1
- from ._base import ContrastType, LinearModelBase, MethodBase
1
+ from ._base import LinearModelBase, MethodBase
2
2
  from ._dge_comparison import DGEEVAL
3
3
  from ._edger import EdgeR
4
4
  from ._pydeseq2 import PyDESeq2
@@ -14,7 +14,6 @@ __all__ = [
14
14
  "SimpleComparisonBase",
15
15
  "WilcoxonTest",
16
16
  "TTest",
17
- "ContrastType",
18
17
  ]
19
18
 
20
19
  AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]