pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -850,7 +850,7 @@ class CompositionalModel2(ABC):
850
850
  table = Table(title="Compositional Analysis summary", box=box.SQUARE, expand=True, highlight=True)
851
851
  table.add_column("Name", justify="left", style="cyan")
852
852
  table.add_column("Value", justify="left")
853
- table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
853
+ table.add_row("Data", f"Data: {data_dims[0]} samples, {data_dims[1]} cell types")
854
854
  table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
855
855
  table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
856
856
  if extended:
@@ -1199,7 +1199,6 @@ class CompositionalModel2(ABC):
1199
1199
  level_order: list[str] = None,
1200
1200
  figsize: tuple[float, float] | None = None,
1201
1201
  dpi: int | None = 100,
1202
- show: bool = True,
1203
1202
  return_fig: bool = False,
1204
1203
  ) -> Figure | None:
1205
1204
  """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
@@ -1278,10 +1277,9 @@ class CompositionalModel2(ABC):
1278
1277
  show_legend=show_legend,
1279
1278
  )
1280
1279
 
1281
- if show:
1282
- plt.show()
1283
1280
  if return_fig:
1284
1281
  return plt.gcf()
1282
+ plt.show()
1285
1283
  return None
1286
1284
 
1287
1285
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1300,7 +1298,6 @@ class CompositionalModel2(ABC):
1300
1298
  args_barplot: dict | None = None,
1301
1299
  figsize: tuple[float, float] | None = None,
1302
1300
  dpi: int | None = 100,
1303
- show: bool = True,
1304
1301
  return_fig: bool = False,
1305
1302
  ) -> Figure | None:
1306
1303
  """Barplot visualization for effects.
@@ -1465,10 +1462,11 @@ class CompositionalModel2(ABC):
1465
1462
  cell_types = pd.unique(plot_df["Cell Type"])
1466
1463
  ax.set_xticklabels(cell_types, rotation=90)
1467
1464
 
1468
- if show:
1469
- plt.show()
1470
- if return_fig:
1465
+ if return_fig and plot_facets:
1466
+ return g
1467
+ if return_fig and not plot_facets:
1471
1468
  return plt.gcf()
1469
+ plt.show()
1472
1470
  return None
1473
1471
 
1474
1472
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1489,7 +1487,6 @@ class CompositionalModel2(ABC):
1489
1487
  level_order: list[str] = None,
1490
1488
  figsize: tuple[float, float] | None = None,
1491
1489
  dpi: int | None = 100,
1492
- show: bool = True,
1493
1490
  return_fig: bool = False,
1494
1491
  ) -> Figure | None:
1495
1492
  """Grouped boxplot visualization.
@@ -1697,10 +1694,11 @@ class CompositionalModel2(ABC):
1697
1694
  title=feature_name,
1698
1695
  )
1699
1696
 
1700
- if show:
1701
- plt.show()
1702
- if return_fig:
1697
+ if return_fig and plot_facets:
1698
+ return g
1699
+ if return_fig and not plot_facets:
1703
1700
  return plt.gcf()
1701
+ plt.show()
1704
1702
  return None
1705
1703
 
1706
1704
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1716,7 +1714,6 @@ class CompositionalModel2(ABC):
1716
1714
  figsize: tuple[float, float] | None = None,
1717
1715
  dpi: int | None = 100,
1718
1716
  ax: plt.Axes | None = None,
1719
- show: bool = True,
1720
1717
  return_fig: bool = False,
1721
1718
  ) -> Figure | None:
1722
1719
  """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
@@ -1820,10 +1817,9 @@ class CompositionalModel2(ABC):
1820
1817
 
1821
1818
  ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
1822
1819
 
1823
- if show:
1824
- plt.show()
1825
1820
  if return_fig:
1826
1821
  return plt.gcf()
1822
+ plt.show()
1827
1823
  return None
1828
1824
 
1829
1825
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1839,7 +1835,6 @@ class CompositionalModel2(ABC):
1839
1835
  figsize: tuple[float, float] | None = (None, None),
1840
1836
  dpi: int | None = 100,
1841
1837
  save: str | bool = False,
1842
- show: bool = True,
1843
1838
  return_fig: bool = False,
1844
1839
  ) -> Tree | None:
1845
1840
  """Plot a tree using input ete3 tree object.
@@ -1903,10 +1898,9 @@ class CompositionalModel2(ABC):
1903
1898
 
1904
1899
  if save is not None:
1905
1900
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1906
- if show:
1907
- return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1908
1901
  if return_fig:
1909
1902
  return tree, tree_style
1903
+ return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1910
1904
  return None
1911
1905
 
1912
1906
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1925,7 +1919,6 @@ class CompositionalModel2(ABC):
1925
1919
  figsize: tuple[float, float] | None = (None, None),
1926
1920
  dpi: int | None = 100,
1927
1921
  save: str | bool = False,
1928
- show: bool = True,
1929
1922
  return_fig: bool = False,
1930
1923
  ) -> Tree | None:
1931
1924
  """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leaf-level significant effects.
@@ -2092,15 +2085,16 @@ class CompositionalModel2(ABC):
2092
2085
 
2093
2086
  if save:
2094
2087
  plt.savefig(save)
2088
+ if return_fig:
2089
+ return plt.gcf()
2095
2090
 
2096
- if save and not show_leaf_effects:
2097
- tree2.render(save, tree_style=tree_style, units=units)
2098
- if show:
2099
- if not show_leaf_effects:
2100
- return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2101
- if return_fig:
2102
- if not show_leaf_effects:
2091
+ else:
2092
+ if save:
2093
+ tree2.render(save, tree_style=tree_style, units=units)
2094
+ if return_fig:
2103
2095
  return tree2, tree_style
2096
+ return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2097
+
2104
2098
  return None
2105
2099
 
2106
2100
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -2115,7 +2109,6 @@ class CompositionalModel2(ABC):
2115
2109
  color_map: Colormap | str | None = None,
2116
2110
  palette: str | Sequence[str] | None = None,
2117
2111
  ax: Axes = None,
2118
- show: bool = True,
2119
2112
  return_fig: bool = False,
2120
2113
  **kwargs,
2121
2114
  ) -> Figure | None:
@@ -2209,10 +2202,9 @@ class CompositionalModel2(ABC):
2209
2202
  **kwargs,
2210
2203
  )
2211
2204
 
2212
- if show:
2213
- plt.show()
2214
2205
  if return_fig:
2215
2206
  return fig
2207
+ plt.show()
2216
2208
  return None
2217
2209
 
2218
2210
 
pertpy/tools/_dialogue.py CHANGED
@@ -82,27 +82,27 @@ class Dialogue:
82
82
 
83
83
  return pseudobulk
84
84
 
85
- def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50) -> pd.DataFrame:
86
- """Return cell-averaged PCA components.
85
+ def _pseudobulk_feature_space(
86
+ self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
87
+ ) -> pd.DataFrame:
88
+ """Return cell-averaged components from the given feature space.
87
89
 
88
90
  TODO: consider merging with `get_pseudobulks`
89
91
  TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
90
92
 
91
93
  Args:
92
- groupby: The key to groupby for pseudobulks
93
- n_components: The number of PCA components
94
+ groupby: The key to groupby for pseudobulks.
95
+ n_components: The number of components to use.
96
+ feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
94
97
 
95
98
  Returns:
96
- A pseudobulk of PCA components.
99
+ A pseudobulk DataFrame of the averaged components.
97
100
  """
98
101
  aggr = {}
99
-
100
102
  for category in adata.obs.loc[:, groupby].cat.categories:
101
103
  temp = adata.obs.loc[:, groupby] == category
102
- aggr[category] = adata[temp].obsm["X_pca"][:, :n_components].mean(axis=0)
103
-
104
+ aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
104
105
  aggr = pd.DataFrame(aggr)
105
-
106
106
  return aggr
107
107
 
108
108
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
@@ -558,7 +558,7 @@ class Dialogue:
558
558
  self,
559
559
  adata: AnnData,
560
560
  ct_order: list[str],
561
- agg_pca: bool = True,
561
+ agg_feature: bool = True,
562
562
  normalize: bool = True,
563
563
  ) -> tuple[list, dict]:
564
564
  """Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
@@ -568,14 +568,14 @@ class Dialogue:
568
568
  Args:
569
569
  adata: AnnData object to generate celltype objects for
570
570
  ct_order: The order of cell types
571
- agg_pca: Whether to aggregate pseudobulks with PCA or not.
571
+ agg_feature: Whether to build the pseudobulks by averaging a feature-space embedding (`_pseudobulk_feature_space`) instead of using `_get_pseudobulks`.
572
572
  normalize: Whether to mimic DIALOGUE behavior or not.
573
573
 
574
574
  Returns:
575
575
  A celltype_label:array dictionary.
576
576
  """
577
577
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
578
- fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
578
+ fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
579
579
  ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
580
580
 
581
581
  # TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
@@ -593,7 +593,7 @@ class Dialogue:
593
593
  adata: AnnData,
594
594
  penalties: list[int] = None,
595
595
  ct_order: list[str] = None,
596
- agg_pca: bool = True,
596
+ agg_feature: bool = True,
597
597
  solver: Literal["lp", "bs"] = "bs",
598
598
  normalize: bool = True,
599
599
  ) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
@@ -606,7 +606,7 @@ class Dialogue:
606
606
  sample_id: Key to use for pseudobulk determination.
607
607
  penalties: PMD penalties.
608
608
  ct_order: The order of cell types.
609
- agg_pca: Whether to calculate cell-averaged PCA components.
609
+ agg_feature: Whether to calculate cell-averaged principal components.
610
610
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
611
611
  For differences between these two, please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
612
612
  normalize: Whether to mimic DIALOGUE as closely as possible
@@ -631,7 +631,7 @@ class Dialogue:
631
631
  else:
632
632
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
633
633
 
634
- mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
634
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
635
635
 
636
636
  n_samples = mcca_in[0].shape[1]
637
637
  if penalties is None:
@@ -1070,7 +1070,6 @@ class Dialogue:
1070
1070
  *,
1071
1071
  split_which: tuple[str, str] = None,
1072
1072
  mcp: str = "mcp_0",
1073
- show: bool = True,
1074
1073
  return_fig: bool = False,
1075
1074
  ) -> Figure | None:
1076
1075
  """Plots split violin plots for a given MCP and split variable.
@@ -1110,10 +1109,9 @@ class Dialogue:
1110
1109
  ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1111
1110
  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1112
1111
 
1113
- if show:
1114
- plt.show()
1115
1112
  if return_fig:
1116
1113
  return plt.gcf()
1114
+ plt.show()
1117
1115
  return None
1118
1116
 
1119
1117
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1125,7 +1123,6 @@ class Dialogue:
1125
1123
  sample_id: str,
1126
1124
  *,
1127
1125
  mcp: str = "mcp_0",
1128
- show: bool = True,
1129
1126
  return_fig: bool = False,
1130
1127
  ) -> Figure | None:
1131
1128
  """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
@@ -1167,8 +1164,7 @@ class Dialogue:
1167
1164
  mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1168
1165
  sns.pairplot(mcp_pivot, hue=color, corner=True)
1169
1166
 
1170
- if show:
1171
- plt.show()
1172
1167
  if return_fig:
1173
1168
  return plt.gcf()
1169
+ plt.show()
1174
1170
  return None
@@ -125,7 +125,6 @@ class MethodBase(ABC):
125
125
  shape_order: list[str] | None = None,
126
126
  x_label: str | None = None,
127
127
  y_label: str | None = None,
128
- show: bool = True,
129
128
  return_fig: bool = False,
130
129
  **kwargs: int,
131
130
  ) -> Figure | None:
@@ -484,10 +483,9 @@ class MethodBase(ABC):
484
483
 
485
484
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
486
485
 
487
- if show:
488
- plt.show()
489
486
  if return_fig:
490
487
  return plt.gcf()
488
+ plt.show()
491
489
  return None
492
490
 
493
491
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -511,7 +509,6 @@ class MethodBase(ABC):
511
509
  pvalue_template=lambda x: f"p={x:.2e}",
512
510
  boxplot_properties=None,
513
511
  palette=None,
514
- show: bool = True,
515
512
  return_fig: bool = False,
516
513
  ) -> Figure | None:
517
514
  """Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
@@ -679,10 +676,9 @@ class MethodBase(ABC):
679
676
  )
680
677
 
681
678
  plt.tight_layout()
682
- if show:
683
- plt.show()
684
679
  if return_fig:
685
680
  return plt.gcf()
681
+ plt.show()
686
682
  return None
687
683
 
688
684
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -696,7 +692,6 @@ class MethodBase(ABC):
696
692
  symbol_col: str = "variable",
697
693
  y_label: str = "Log2 fold change",
698
694
  figsize: tuple[int, int] = (10, 5),
699
- show: bool = True,
700
695
  return_fig: bool = False,
701
696
  **barplot_kwargs,
702
697
  ) -> Figure | None:
@@ -762,10 +757,9 @@ class MethodBase(ABC):
762
757
  plt.xlabel("")
763
758
  plt.ylabel(y_label)
764
759
 
765
- if show:
766
- plt.show()
767
760
  if return_fig:
768
761
  return plt.gcf()
762
+ plt.show()
769
763
  return None
770
764
 
771
765
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -782,7 +776,6 @@ class MethodBase(ABC):
782
776
  figsize: tuple[int, int] = (10, 2),
783
777
  x_label: str = "Contrast",
784
778
  y_label: str = "Gene",
785
- show: bool = True,
786
779
  return_fig: bool = False,
787
780
  **heatmap_kwargs,
788
781
  ) -> Figure | None:
@@ -880,10 +873,9 @@ class MethodBase(ABC):
880
873
  plt.xlabel(x_label)
881
874
  plt.ylabel(y_label)
882
875
 
883
- if show:
884
- plt.show()
885
876
  if return_fig:
886
877
  return plt.gcf()
878
+ plt.show()
887
879
  return None
888
880
 
889
881
 
@@ -1117,67 +1117,75 @@ class MeanVarDistributionDistance(AbstractDistance):
1117
1117
  super().__init__()
1118
1118
  self.accepts_precomputed = False
1119
1119
 
1120
+ @staticmethod
1121
+ def _mean_var(x, log: bool = False):
1122
+ mean = np.mean(x, axis=0)
1123
+ var = np.var(x, axis=0)
1124
+ positive = mean > 0
1125
+ mean = mean[positive]
1126
+ var = var[positive]
1127
+ if log:
1128
+ mean = np.log(mean)
1129
+ var = np.log(var)
1130
+ return mean, var
1131
+
1132
+ @staticmethod
1133
+ def _prep_kde_data(x, y):
1134
+ return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1135
+
1136
+ @staticmethod
1137
+ def _grid_points(d, n_points=100):
1138
+ # Make grid, add 1 bin on lower/upper end to get final n_points
1139
+ d_min = d.min()
1140
+ d_max = d.max()
1141
+ # Compute bin size
1142
+ d_bin = (d_max - d_min) / (n_points - 2)
1143
+ d_min = d_min - d_bin
1144
+ d_max = d_max + d_bin
1145
+ return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1146
+
1147
+ @staticmethod
1148
+ def _kde_eval_both(x_kde, y_kde, grid):
1149
+ n_points = len(grid)
1150
+ chunk_size = 10000
1151
+
1152
+ result_x = np.zeros(n_points)
1153
+ result_y = np.zeros(n_points)
1154
+
1155
+ # Process same chunks for both KDEs
1156
+ for start in range(0, n_points, chunk_size):
1157
+ end = min(start + chunk_size, n_points)
1158
+ chunk = grid[start:end]
1159
+ result_x[start:end] = x_kde.score_samples(chunk)
1160
+ result_y[start:end] = y_kde.score_samples(chunk)
1161
+
1162
+ return result_x, result_y
1163
+
1120
1164
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
1121
1165
  """Difference of mean-var distributions in 2 matrices.
1122
-
1123
1166
  Args:
1124
1167
  X: Normalized and log transformed cells x genes count matrix.
1125
1168
  Y: Normalized and log transformed cells x genes count matrix.
1126
1169
  """
1170
+ mean_x, var_x = self._mean_var(X, log=True)
1171
+ mean_y, var_y = self._mean_var(Y, log=True)
1127
1172
 
1128
- def _mean_var(x, log: bool = False):
1129
- mean = np.mean(x, axis=0)
1130
- var = np.var(x, axis=0)
1131
- positive = mean > 0
1132
- mean = mean[positive]
1133
- var = var[positive]
1134
- if log:
1135
- mean = np.log(mean)
1136
- var = np.log(var)
1137
- return mean, var
1138
-
1139
- def _prep_kde_data(x, y):
1140
- return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1141
-
1142
- def _grid_points(d, n_points=100):
1143
- # Make grid, add 1 bin on lower/upper end to get final n_points
1144
- d_min = d.min()
1145
- d_max = d.max()
1146
- # Compute bin size
1147
- d_bin = (d_max - d_min) / (n_points - 2)
1148
- d_min = d_min - d_bin
1149
- d_max = d_max + d_bin
1150
- return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1151
-
1152
- def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
1153
- # the thread_count is determined using the factor 0.875 as recommended here:
1154
- # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
1155
- with multiprocessing.Pool(thread_count) as p:
1156
- return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
1157
-
1158
- def _kde_eval(d, grid):
1159
- # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
1160
- # can not be compared well on regions further away from the data as they are -inf
1161
- kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
1162
- return _parallel_score_samples(kde, grid)
1163
-
1164
- mean_x, var_x = _mean_var(X, log=True)
1165
- mean_y, var_y = _mean_var(Y, log=True)
1166
-
1167
- x = _prep_kde_data(mean_x, var_x)
1168
- y = _prep_kde_data(mean_y, var_y)
1173
+ x = self._prep_kde_data(mean_x, var_x)
1174
+ y = self._prep_kde_data(mean_y, var_y)
1169
1175
 
1170
1176
  # Gridpoints to eval KDE on
1171
- mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
1172
- var_grid = _grid_points(np.concatenate([var_x, var_y]))
1177
+ mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
1178
+ var_grid = self._grid_points(np.concatenate([var_x, var_y]))
1173
1179
  grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
1174
1180
 
1175
- kde_x = _kde_eval(x, grid)
1176
- kde_y = _kde_eval(y, grid)
1181
+ # Fit both KDEs first
1182
+ x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
1183
+ y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
1177
1184
 
1178
- kde_diff = ((kde_x - kde_y) ** 2).mean()
1185
+ # Evaluate both KDEs on same grid chunks
1186
+ kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
1179
1187
 
1180
- return kde_diff
1188
+ return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
1181
1189
 
1182
1190
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
1183
1191
  raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
@@ -304,7 +304,6 @@ class Enrichment:
304
304
  groupby: str = None,
305
305
  key: str = "pertpy_enrichment",
306
306
  ax: Axes | None = None,
307
- show: bool = True,
308
307
  return_fig: bool = False,
309
308
  **kwargs,
310
309
  ) -> DotPlot | None:
@@ -417,10 +416,9 @@ class Enrichment:
417
416
  **kwargs,
418
417
  )
419
418
 
420
- if show:
421
- plt.show()
422
419
  if return_fig:
423
420
  return fig
421
+ plt.show()
424
422
  return None
425
423
 
426
424
  def plot_gsea(
pertpy/tools/_milo.py CHANGED
@@ -727,7 +727,6 @@ class Milo:
727
727
  color_map: Colormap | str | None = None,
728
728
  palette: str | Sequence[str] | None = None,
729
729
  ax: Axes | None = None,
730
- show: bool = True,
731
730
  return_fig: bool = False,
732
731
  **kwargs,
733
732
  ) -> Figure | None:
@@ -803,10 +802,9 @@ class Milo:
803
802
  **kwargs,
804
803
  )
805
804
 
806
- if show:
807
- plt.show()
808
805
  if return_fig:
809
806
  return fig
807
+ plt.show()
810
808
  return None
811
809
 
812
810
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -820,7 +818,6 @@ class Milo:
820
818
  color_map: Colormap | str | None = None,
821
819
  palette: str | Sequence[str] | None = None,
822
820
  ax: Axes | None = None,
823
- show: bool = True,
824
821
  return_fig: bool = False,
825
822
  **kwargs,
826
823
  ) -> Figure | None:
@@ -866,10 +863,9 @@ class Milo:
866
863
  **kwargs,
867
864
  )
868
865
 
869
- if show:
870
- plt.show()
871
866
  if return_fig:
872
867
  return fig
868
+ plt.show()
873
869
  return None
874
870
 
875
871
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -882,7 +878,6 @@ class Milo:
882
878
  alpha: float = 0.1,
883
879
  subset_nhoods: list[str] = None,
884
880
  palette: str | Sequence[str] | dict[str, str] | None = None,
885
- show: bool = True,
886
881
  return_fig: bool = False,
887
882
  ) -> Figure | None:
888
883
  """Plot beeswarm plot of logFC against nhood labels
@@ -994,10 +989,9 @@ class Milo:
994
989
  plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
995
990
  plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
996
991
 
997
- if show:
998
- plt.show()
999
992
  if return_fig:
1000
993
  return plt.gcf()
994
+ plt.show()
1001
995
  return None
1002
996
 
1003
997
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1008,7 +1002,6 @@ class Milo:
1008
1002
  *,
1009
1003
  subset_nhoods: list[str] = None,
1010
1004
  log_counts: bool = False,
1011
- show: bool = True,
1012
1005
  return_fig: bool = False,
1013
1006
  ) -> Figure | None:
1014
1007
  """Plot boxplot of cell numbers vs condition of interest.
@@ -1050,8 +1043,7 @@ class Milo:
1050
1043
  plt.xticks(rotation=90)
1051
1044
  plt.xlabel(test_var)
1052
1045
 
1053
- if show:
1054
- plt.show()
1055
1046
  if return_fig:
1056
1047
  return plt.gcf()
1048
+ plt.show()
1057
1049
  return None