pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from pathlib import Path
5
- from typing import TYPE_CHECKING, Literal, Optional, Union
5
+ from typing import TYPE_CHECKING, Literal
6
6
 
7
7
  import arviz as az
8
8
  import jax.numpy as jnp
@@ -26,6 +26,8 @@ from rich.console import Console
26
26
  from rich.table import Table
27
27
  from scipy.cluster import hierarchy as sp_hierarchy
28
28
 
29
+ from pertpy._doc import _doc_params, doc_common_plot_args
30
+
29
31
  if TYPE_CHECKING:
30
32
  from collections.abc import Sequence
31
33
 
@@ -307,7 +309,7 @@ class CompositionalModel2(ABC):
307
309
  if copy:
308
310
  sample_adata = sample_adata.copy()
309
311
 
310
- rng_key_array = random.key(rng_key)
312
+ rng_key_array = random.key_data(random.key(rng_key))
311
313
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
312
314
 
313
315
  # Set up NUTS kernel
@@ -848,7 +850,7 @@ class CompositionalModel2(ABC):
848
850
  table = Table(title="Compositional Analysis summary", box=box.SQUARE, expand=True, highlight=True)
849
851
  table.add_column("Name", justify="left", style="cyan")
850
852
  table.add_column("Value", justify="left")
851
- table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
853
+ table.add_row("Data", f"Data: {data_dims[0]} samples, {data_dims[1]} cell types")
852
854
  table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
853
855
  table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
854
856
  if extended:
@@ -1023,7 +1025,7 @@ class CompositionalModel2(ABC):
1023
1025
  >>> key_added="lineage", add_level_name=True
1024
1026
  >>> )
1025
1027
  >>> mdata = tasccoda.prepare(
1026
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1028
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi" : 0}
1027
1029
  >>> )
1028
1030
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1029
1031
  >>> node_effects = tasccoda.get_node_df(mdata)
@@ -1185,22 +1187,20 @@ class CompositionalModel2(ABC):
1185
1187
 
1186
1188
  return ax
1187
1189
 
1190
+ @_doc_params(common_plot_args=doc_common_plot_args)
1188
1191
  def plot_stacked_barplot( # pragma: no cover
1189
1192
  self,
1190
1193
  data: AnnData | MuData,
1191
1194
  feature_name: str,
1195
+ *,
1192
1196
  modality_key: str = "coda",
1193
1197
  palette: ListedColormap | None = cm.tab20,
1194
1198
  show_legend: bool | None = True,
1195
1199
  level_order: list[str] = None,
1196
1200
  figsize: tuple[float, float] | None = None,
1197
1201
  dpi: int | None = 100,
1198
- return_fig: bool | None = None,
1199
- ax: plt.Axes | None = None,
1200
- show: bool | None = None,
1201
- save: str | bool | None = None,
1202
- **kwargs,
1203
- ) -> plt.Axes | plt.Figure | None:
1202
+ return_fig: bool = False,
1203
+ ) -> Figure | None:
1204
1204
  """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
1205
1205
 
1206
1206
  Args:
@@ -1212,9 +1212,10 @@ class CompositionalModel2(ABC):
1212
1212
  palette: The matplotlib color map for the barplot.
1213
1213
  show_legend: If True, adds a legend.
1214
1214
  level_order: Custom ordering of bars on the x-axis.
1215
+ {common_plot_args}
1215
1216
 
1216
1217
  Returns:
1217
- A :class:`~matplotlib.axes.Axes` object
1218
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1218
1219
 
1219
1220
  Examples:
1220
1221
  >>> import pertpy as pt
@@ -1239,7 +1240,7 @@ class CompositionalModel2(ABC):
1239
1240
  if level_order:
1240
1241
  assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
1241
1242
  data = data[level_order]
1242
- ax = self._stackbar(
1243
+ self._stackbar(
1243
1244
  data.X,
1244
1245
  type_names=data.var.index,
1245
1246
  title="samples",
@@ -1265,7 +1266,7 @@ class CompositionalModel2(ABC):
1265
1266
  l_indices = np.where(data.obs[feature_name] == levels[level])
1266
1267
  feature_totals[level] = np.sum(data.X[l_indices], axis=0)
1267
1268
 
1268
- ax = self._stackbar(
1269
+ self._stackbar(
1269
1270
  feature_totals,
1270
1271
  type_names=ct_names,
1271
1272
  title=feature_name,
@@ -1276,19 +1277,16 @@ class CompositionalModel2(ABC):
1276
1277
  show_legend=show_legend,
1277
1278
  )
1278
1279
 
1279
- if save:
1280
- plt.savefig(save, bbox_inches="tight")
1281
- if show:
1282
- plt.show()
1283
1280
  if return_fig:
1284
1281
  return plt.gcf()
1285
- if not (show or save):
1286
- return ax
1282
+ plt.show()
1287
1283
  return None
1288
1284
 
1285
+ @_doc_params(common_plot_args=doc_common_plot_args)
1289
1286
  def plot_effects_barplot( # pragma: no cover
1290
1287
  self,
1291
1288
  data: AnnData | MuData,
1289
+ *,
1292
1290
  modality_key: str = "coda",
1293
1291
  covariates: str | list | None = None,
1294
1292
  parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
@@ -1300,11 +1298,8 @@ class CompositionalModel2(ABC):
1300
1298
  args_barplot: dict | None = None,
1301
1299
  figsize: tuple[float, float] | None = None,
1302
1300
  dpi: int | None = 100,
1303
- return_fig: bool | None = None,
1304
- ax: plt.Axes | None = None,
1305
- show: bool | None = None,
1306
- save: str | bool | None = None,
1307
- ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1301
+ return_fig: bool = False,
1302
+ ) -> Figure | None:
1308
1303
  """Barplot visualization for effects.
1309
1304
 
1310
1305
  The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
@@ -1323,10 +1318,10 @@ class CompositionalModel2(ABC):
1323
1318
  palette: The seaborn color map for the barplot.
1324
1319
  level_order: Custom ordering of bars on the x-axis.
1325
1320
  args_barplot: Arguments passed to sns.barplot.
1321
+ {common_plot_args}
1326
1322
 
1327
1323
  Returns:
1328
- Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1329
- or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1324
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1330
1325
 
1331
1326
  Examples:
1332
1327
  >>> import pertpy as pt
@@ -1380,7 +1375,6 @@ class CompositionalModel2(ABC):
1380
1375
  if len(covariate_names_zero) != 0:
1381
1376
  if plot_facets:
1382
1377
  if plot_zero_covariate and not plot_zero_cell_type:
1383
- plot_df = plot_df[plot_df["value"] != 0]
1384
1378
  for covariate_name_zero in covariate_names_zero:
1385
1379
  new_row = {
1386
1380
  "Covariate": covariate_name_zero,
@@ -1437,16 +1431,6 @@ class CompositionalModel2(ABC):
1437
1431
  if ax.get_xticklabels()[0]._text == "zero":
1438
1432
  ax.set_xticks([])
1439
1433
 
1440
- if save:
1441
- plt.savefig(save, bbox_inches="tight")
1442
- if show:
1443
- plt.show()
1444
- if return_fig:
1445
- return plt.gcf()
1446
- if not (show or save):
1447
- return g
1448
- return None
1449
-
1450
1434
  # If not plot as facets, call barplot to plot cell types on the x-axis.
1451
1435
  else:
1452
1436
  _, ax = plt.subplots(figsize=figsize, dpi=dpi)
@@ -1478,20 +1462,19 @@ class CompositionalModel2(ABC):
1478
1462
  cell_types = pd.unique(plot_df["Cell Type"])
1479
1463
  ax.set_xticklabels(cell_types, rotation=90)
1480
1464
 
1481
- if save:
1482
- plt.savefig(save, bbox_inches="tight")
1483
- if show:
1484
- plt.show()
1485
- if return_fig:
1486
- return plt.gcf()
1487
- if not (show or save):
1488
- return ax
1489
- return None
1465
+ if return_fig and plot_facets:
1466
+ return g
1467
+ if return_fig and not plot_facets:
1468
+ return plt.gcf()
1469
+ plt.show()
1470
+ return None
1490
1471
 
1472
+ @_doc_params(common_plot_args=doc_common_plot_args)
1491
1473
  def plot_boxplots( # pragma: no cover
1492
1474
  self,
1493
1475
  data: AnnData | MuData,
1494
1476
  feature_name: str,
1477
+ *,
1495
1478
  modality_key: str = "coda",
1496
1479
  y_scale: Literal["relative", "log", "log10", "count"] = "relative",
1497
1480
  plot_facets: bool = False,
@@ -1504,11 +1487,8 @@ class CompositionalModel2(ABC):
1504
1487
  level_order: list[str] = None,
1505
1488
  figsize: tuple[float, float] | None = None,
1506
1489
  dpi: int | None = 100,
1507
- return_fig: bool | None = None,
1508
- ax: plt.Axes | None = None,
1509
- show: bool | None = None,
1510
- save: str | bool | None = None,
1511
- ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1490
+ return_fig: bool = False,
1491
+ ) -> Figure | None:
1512
1492
  """Grouped boxplot visualization.
1513
1493
 
1514
1494
  The cell counts for each cell type are shown as a group of boxplots
@@ -1530,10 +1510,10 @@ class CompositionalModel2(ABC):
1530
1510
  palette: The seaborn color map for the barplot.
1531
1511
  show_legend: If True, adds a legend.
1532
1512
  level_order: Custom ordering of bars on the x-axis.
1513
+ {common_plot_args}
1533
1514
 
1534
1515
  Returns:
1535
- Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1536
- or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1516
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1537
1517
 
1538
1518
  Examples:
1539
1519
  >>> import pertpy as pt
@@ -1651,16 +1631,6 @@ class CompositionalModel2(ABC):
1651
1631
  **args_swarmplot,
1652
1632
  ).set_titles("{col_name}")
1653
1633
 
1654
- if save:
1655
- plt.savefig(save, bbox_inches="tight")
1656
- if show:
1657
- plt.show()
1658
- if return_fig:
1659
- return plt.gcf()
1660
- if not (show or save):
1661
- return g
1662
- return None
1663
-
1664
1634
  # If not plot as facets, call boxplot to plot cell types on the x-axis.
1665
1635
  else:
1666
1636
  if level_order:
@@ -1724,19 +1694,18 @@ class CompositionalModel2(ABC):
1724
1694
  title=feature_name,
1725
1695
  )
1726
1696
 
1727
- if save:
1728
- plt.savefig(save, bbox_inches="tight")
1729
- if show:
1730
- plt.show()
1731
- if return_fig:
1732
- return plt.gcf()
1733
- if not (show or save):
1734
- return ax
1735
- return None
1697
+ if return_fig and plot_facets:
1698
+ return g
1699
+ if return_fig and not plot_facets:
1700
+ return plt.gcf()
1701
+ plt.show()
1702
+ return None
1736
1703
 
1704
+ @_doc_params(common_plot_args=doc_common_plot_args)
1737
1705
  def plot_rel_abundance_dispersion_plot( # pragma: no cover
1738
1706
  self,
1739
1707
  data: AnnData | MuData,
1708
+ *,
1740
1709
  modality_key: str = "coda",
1741
1710
  abundant_threshold: float | None = 0.9,
1742
1711
  default_color: str | None = "Grey",
@@ -1744,11 +1713,9 @@ class CompositionalModel2(ABC):
1744
1713
  label_cell_types: bool = True,
1745
1714
  figsize: tuple[float, float] | None = None,
1746
1715
  dpi: int | None = 100,
1747
- return_fig: bool | None = None,
1748
1716
  ax: plt.Axes | None = None,
1749
- show: bool | None = None,
1750
- save: str | bool | None = None,
1751
- ) -> plt.Axes | plt.Figure | None:
1717
+ return_fig: bool = False,
1718
+ ) -> Figure | None:
1752
1719
  """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
1753
1720
 
1754
1721
  If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
@@ -1763,9 +1730,10 @@ class CompositionalModel2(ABC):
1763
1730
  figsize: Figure size.
1764
1731
  dpi: Dpi setting.
1765
1732
  ax: A matplotlib axes object. Only works if plotting a single component.
1733
+ {common_plot_args}
1766
1734
 
1767
1735
  Returns:
1768
- A :class:`~matplotlib.axes.Axes` object
1736
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1769
1737
 
1770
1738
  Examples:
1771
1739
  >>> import pertpy as pt
@@ -1849,19 +1817,16 @@ class CompositionalModel2(ABC):
1849
1817
 
1850
1818
  ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
1851
1819
 
1852
- if save:
1853
- plt.savefig(save, bbox_inches="tight")
1854
- if show:
1855
- plt.show()
1856
1820
  if return_fig:
1857
1821
  return plt.gcf()
1858
- if not (show or save):
1859
- return ax
1822
+ plt.show()
1860
1823
  return None
1861
1824
 
1825
+ @_doc_params(common_plot_args=doc_common_plot_args)
1862
1826
  def plot_draw_tree( # pragma: no cover
1863
1827
  self,
1864
1828
  data: AnnData | MuData,
1829
+ *,
1865
1830
  modality_key: str = "coda",
1866
1831
  tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1867
1832
  tight_text: bool | None = False,
@@ -1869,8 +1834,8 @@ class CompositionalModel2(ABC):
1869
1834
  units: Literal["px", "mm", "in"] | None = "px",
1870
1835
  figsize: tuple[float, float] | None = (None, None),
1871
1836
  dpi: int | None = 100,
1872
- show: bool | None = True,
1873
- save: str | bool | None = None,
1837
+ save: str | bool = False,
1838
+ return_fig: bool = False,
1874
1839
  ) -> Tree | None:
1875
1840
  """Plot a tree using input ete3 tree object.
1876
1841
 
@@ -1881,12 +1846,11 @@ class CompositionalModel2(ABC):
1881
1846
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
1882
1847
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1883
1848
  show_scale: Include the scale legend in the tree image or not.
1884
- show: If True, plot the tree inline. If false, return tree and tree_style objects.
1885
- file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
1886
- Output image can be saved whether show is True or not.
1887
1849
  units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
1888
1850
  figsize: Figure size.
1889
1851
  dpi: Dots per inches.
1852
+ save: Save the tree plot to a file. You can specify the file name here.
1853
+ {common_plot_args}
1890
1854
 
1891
1855
  Returns:
1892
1856
  Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
@@ -1901,7 +1865,7 @@ class CompositionalModel2(ABC):
1901
1865
  >>> key_added="lineage", add_level_name=True
1902
1866
  >>> )
1903
1867
  >>> mdata = tasccoda.prepare(
1904
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1868
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args=dict(phi=0)
1905
1869
  >>> )
1906
1870
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1907
1871
  >>> tasccoda.plot_draw_tree(mdata, tree="lineage")
@@ -1934,15 +1898,17 @@ class CompositionalModel2(ABC):
1934
1898
 
1935
1899
  if save is not None:
1936
1900
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1937
- if show:
1938
- return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1939
- else:
1901
+ if return_fig:
1940
1902
  return tree, tree_style
1903
+ return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1904
+ return None
1941
1905
 
1906
+ @_doc_params(common_plot_args=doc_common_plot_args)
1942
1907
  def plot_draw_effects( # pragma: no cover
1943
1908
  self,
1944
1909
  data: AnnData | MuData,
1945
1910
  covariate: str,
1911
+ *,
1946
1912
  modality_key: str = "coda",
1947
1913
  tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1948
1914
  show_legend: bool | None = None,
@@ -1952,8 +1918,8 @@ class CompositionalModel2(ABC):
1952
1918
  units: Literal["px", "mm", "in"] | None = "px",
1953
1919
  figsize: tuple[float, float] | None = (None, None),
1954
1920
  dpi: int | None = 100,
1955
- show: bool | None = True,
1956
- save: str | None = None,
1921
+ save: str | bool = False,
1922
+ return_fig: bool = False,
1957
1923
  ) -> Tree | None:
1958
1924
  """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
1959
1925
 
@@ -1968,15 +1934,15 @@ class CompositionalModel2(ABC):
1968
1934
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
1969
1935
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1970
1936
  show_scale: Include the scale legend in the tree image or not.
1971
- show: If True, plot the tree inline. If false, return tree and tree_style objects.
1972
- file_name: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
1973
1937
  units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
1974
1938
  figsize: Figure size.
1975
1939
  dpi: Dots per inches.
1940
+ save: Save the tree plot to a file. You can specify the file name here.
1941
+ {common_plot_args}
1976
1942
 
1977
1943
  Returns:
1978
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
1979
- or plot the tree inline (`show = False`)
1944
+ Returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`return_fig = False`)
1945
+ or plot the tree inline (`show = True`)
1980
1946
 
1981
1947
  Examples:
1982
1948
  >>> import pertpy as pt
@@ -1988,7 +1954,7 @@ class CompositionalModel2(ABC):
1988
1954
  >>> key_added="lineage", add_level_name=True
1989
1955
  >>> )
1990
1956
  >>> mdata = tasccoda.prepare(
1991
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1957
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args=dict(phi=0)
1992
1958
  >>> )
1993
1959
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1994
1960
  >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
@@ -2117,52 +2083,55 @@ class CompositionalModel2(ABC):
2117
2083
  plt.xlim(-leaf_eff_max, leaf_eff_max)
2118
2084
  plt.subplots_adjust(wspace=0)
2119
2085
 
2120
- if save is not None:
2086
+ if save:
2121
2087
  plt.savefig(save)
2088
+ if return_fig:
2089
+ return plt.gcf()
2122
2090
 
2123
- if save is not None and not show_leaf_effects:
2124
- tree2.render(save, tree_style=tree_style, units=units)
2125
- if show:
2126
- if not show_leaf_effects:
2127
- return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2128
2091
  else:
2129
- if not show_leaf_effects:
2092
+ if save:
2093
+ tree2.render(save, tree_style=tree_style, units=units)
2094
+ if return_fig:
2130
2095
  return tree2, tree_style
2096
+ return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2097
+
2131
2098
  return None
2132
2099
 
2100
+ @_doc_params(common_plot_args=doc_common_plot_args)
2133
2101
  def plot_effects_umap( # pragma: no cover
2134
2102
  self,
2135
2103
  mdata: MuData,
2136
2104
  effect_name: str | list | None,
2137
2105
  cluster_key: str,
2106
+ *,
2138
2107
  modality_key_1: str = "rna",
2139
2108
  modality_key_2: str = "coda",
2140
2109
  color_map: Colormap | str | None = None,
2141
2110
  palette: str | Sequence[str] | None = None,
2142
- return_fig: bool | None = None,
2143
2111
  ax: Axes = None,
2144
- show: bool = None,
2145
- save: str | bool | None = None,
2112
+ return_fig: bool = False,
2146
2113
  **kwargs,
2147
- ) -> plt.Axes | plt.Figure | None:
2114
+ ) -> Figure | None:
2148
2115
  """Plot a UMAP visualization colored by effect strength.
2149
2116
 
2150
2117
  Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
2151
2118
  (default is data['rna']) depending on the cluster they were assigned to.
2152
2119
 
2153
2120
  Args:
2154
- mudata: MuData object.
2121
+ mdata: MuData object.
2155
2122
  effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
2156
2123
  cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
2157
2124
  To assign cell types' effects to original cells.
2158
2125
  modality_key_1: Key to the cell-level AnnData in the MuData object.
2159
2126
  modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
2160
- show: Whether to display the figure or return axis.
2127
+ color_map: The color map to use for plotting.
2128
+ palette: The color palette to use for plotting.
2161
2129
  ax: A matplotlib axes object. Only works if plotting a single component.
2130
+ {common_plot_args}
2162
2131
  **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
2163
2132
 
2164
2133
  Returns:
2165
- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
2134
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
2166
2135
 
2167
2136
  Examples:
2168
2137
  >>> import pertpy as pt
@@ -2182,7 +2151,7 @@ class CompositionalModel2(ABC):
2182
2151
  >>> modality_key="coda",
2183
2152
  >>> reference_cell_type="18",
2184
2153
  >>> formula="condition",
2185
- >>> pen_args={"phi": 0, "lambda_1": 3.5},
2154
+ >>> pen_args=dict(phi=0, lambda_1=3.5),
2186
2155
  >>> tree_key="tree"
2187
2156
  >>> )
2188
2157
  >>> tasccoda_model.run_nuts(
@@ -2220,7 +2189,7 @@ class CompositionalModel2(ABC):
2220
2189
  else:
2221
2190
  vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))
2222
2191
 
2223
- return sc.pl.umap(
2192
+ fig = sc.pl.umap(
2224
2193
  data_rna,
2225
2194
  color=effect_name,
2226
2195
  vmax=vmax,
@@ -2229,11 +2198,15 @@ class CompositionalModel2(ABC):
2229
2198
  color_map=color_map,
2230
2199
  return_fig=return_fig,
2231
2200
  ax=ax,
2232
- show=show,
2233
- save=save,
2201
+ show=False,
2234
2202
  **kwargs,
2235
2203
  )
2236
2204
 
2205
+ if return_fig:
2206
+ return fig
2207
+ plt.show()
2208
+ return None
2209
+
2237
2210
 
2238
2211
  def get_a(
2239
2212
  tree: tt.core.ToyTree,