pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -850,7 +850,7 @@ class CompositionalModel2(ABC):
850
850
  table = Table(title="Compositional Analysis summary", box=box.SQUARE, expand=True, highlight=True)
851
851
  table.add_column("Name", justify="left", style="cyan")
852
852
  table.add_column("Value", justify="left")
853
- table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
853
+ table.add_row("Data", f"Data: {data_dims[0]} samples, {data_dims[1]} cell types")
854
854
  table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
855
855
  table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
856
856
  if extended:
@@ -1199,7 +1199,6 @@ class CompositionalModel2(ABC):
1199
1199
  level_order: list[str] = None,
1200
1200
  figsize: tuple[float, float] | None = None,
1201
1201
  dpi: int | None = 100,
1202
- show: bool = True,
1203
1202
  return_fig: bool = False,
1204
1203
  ) -> Figure | None:
1205
1204
  """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
@@ -1278,10 +1277,9 @@ class CompositionalModel2(ABC):
1278
1277
  show_legend=show_legend,
1279
1278
  )
1280
1279
 
1281
- if show:
1282
- plt.show()
1283
1280
  if return_fig:
1284
1281
  return plt.gcf()
1282
+ plt.show()
1285
1283
  return None
1286
1284
 
1287
1285
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1300,7 +1298,6 @@ class CompositionalModel2(ABC):
1300
1298
  args_barplot: dict | None = None,
1301
1299
  figsize: tuple[float, float] | None = None,
1302
1300
  dpi: int | None = 100,
1303
- show: bool = True,
1304
1301
  return_fig: bool = False,
1305
1302
  ) -> Figure | None:
1306
1303
  """Barplot visualization for effects.
@@ -1465,10 +1462,11 @@ class CompositionalModel2(ABC):
1465
1462
  cell_types = pd.unique(plot_df["Cell Type"])
1466
1463
  ax.set_xticklabels(cell_types, rotation=90)
1467
1464
 
1468
- if show:
1469
- plt.show()
1470
- if return_fig:
1465
+ if return_fig and plot_facets:
1466
+ return g
1467
+ if return_fig and not plot_facets:
1471
1468
  return plt.gcf()
1469
+ plt.show()
1472
1470
  return None
1473
1471
 
1474
1472
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1489,7 +1487,6 @@ class CompositionalModel2(ABC):
1489
1487
  level_order: list[str] = None,
1490
1488
  figsize: tuple[float, float] | None = None,
1491
1489
  dpi: int | None = 100,
1492
- show: bool = True,
1493
1490
  return_fig: bool = False,
1494
1491
  ) -> Figure | None:
1495
1492
  """Grouped boxplot visualization.
@@ -1697,10 +1694,11 @@ class CompositionalModel2(ABC):
1697
1694
  title=feature_name,
1698
1695
  )
1699
1696
 
1700
- if show:
1701
- plt.show()
1702
- if return_fig:
1697
+ if return_fig and plot_facets:
1698
+ return g
1699
+ if return_fig and not plot_facets:
1703
1700
  return plt.gcf()
1701
+ plt.show()
1704
1702
  return None
1705
1703
 
1706
1704
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1716,7 +1714,6 @@ class CompositionalModel2(ABC):
1716
1714
  figsize: tuple[float, float] | None = None,
1717
1715
  dpi: int | None = 100,
1718
1716
  ax: plt.Axes | None = None,
1719
- show: bool = True,
1720
1717
  return_fig: bool = False,
1721
1718
  ) -> Figure | None:
1722
1719
  """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
@@ -1820,10 +1817,9 @@ class CompositionalModel2(ABC):
1820
1817
 
1821
1818
  ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
1822
1819
 
1823
- if show:
1824
- plt.show()
1825
1820
  if return_fig:
1826
1821
  return plt.gcf()
1822
+ plt.show()
1827
1823
  return None
1828
1824
 
1829
1825
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1839,7 +1835,6 @@ class CompositionalModel2(ABC):
1839
1835
  figsize: tuple[float, float] | None = (None, None),
1840
1836
  dpi: int | None = 100,
1841
1837
  save: str | bool = False,
1842
- show: bool = True,
1843
1838
  return_fig: bool = False,
1844
1839
  ) -> Tree | None:
1845
1840
  """Plot a tree using input ete3 tree object.
@@ -1903,10 +1898,9 @@ class CompositionalModel2(ABC):
1903
1898
 
1904
1899
  if save is not None:
1905
1900
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1906
- if show:
1907
- return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1908
1901
  if return_fig:
1909
1902
  return tree, tree_style
1903
+ return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1910
1904
  return None
1911
1905
 
1912
1906
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1925,7 +1919,6 @@ class CompositionalModel2(ABC):
1925
1919
  figsize: tuple[float, float] | None = (None, None),
1926
1920
  dpi: int | None = 100,
1927
1921
  save: str | bool = False,
1928
- show: bool = True,
1929
1922
  return_fig: bool = False,
1930
1923
  ) -> Tree | None:
1931
1924
  """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
@@ -2092,15 +2085,16 @@ class CompositionalModel2(ABC):
2092
2085
 
2093
2086
  if save:
2094
2087
  plt.savefig(save)
2088
+ if return_fig:
2089
+ return plt.gcf()
2095
2090
 
2096
- if save and not show_leaf_effects:
2097
- tree2.render(save, tree_style=tree_style, units=units)
2098
- if show:
2099
- if not show_leaf_effects:
2100
- return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2101
- if return_fig:
2102
- if not show_leaf_effects:
2091
+ else:
2092
+ if save:
2093
+ tree2.render(save, tree_style=tree_style, units=units)
2094
+ if return_fig:
2103
2095
  return tree2, tree_style
2096
+ return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2097
+
2104
2098
  return None
2105
2099
 
2106
2100
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -2115,7 +2109,6 @@ class CompositionalModel2(ABC):
2115
2109
  color_map: Colormap | str | None = None,
2116
2110
  palette: str | Sequence[str] | None = None,
2117
2111
  ax: Axes = None,
2118
- show: bool = True,
2119
2112
  return_fig: bool = False,
2120
2113
  **kwargs,
2121
2114
  ) -> Figure | None:
@@ -2209,10 +2202,9 @@ class CompositionalModel2(ABC):
2209
2202
  **kwargs,
2210
2203
  )
2211
2204
 
2212
- if show:
2213
- plt.show()
2214
2205
  if return_fig:
2215
2206
  return fig
2207
+ plt.show()
2216
2208
  return None
2217
2209
 
2218
2210
 
pertpy/tools/_dialogue.py CHANGED
@@ -82,27 +82,27 @@ class Dialogue:
82
82
 
83
83
  return pseudobulk
84
84
 
85
- def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50) -> pd.DataFrame:
86
- """Return cell-averaged PCA components.
85
+ def _pseudobulk_feature_space(
86
+ self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
87
+ ) -> pd.DataFrame:
88
+ """Return Cell-averaged components from a passed feature space.
87
89
 
88
90
  TODO: consider merging with `get_pseudobulks`
89
91
  TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
90
92
 
91
93
  Args:
92
- groupby: The key to groupby for pseudobulks
93
- n_components: The number of PCA components
94
+ groupby: The key to groupby for pseudobulks.
95
+ n_components: The number of components to use.
96
+ feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
94
97
 
95
98
  Returns:
96
- A pseudobulk of PCA components.
99
+ A pseudobulk DataFrame of the averaged components.
97
100
  """
98
101
  aggr = {}
99
-
100
102
  for category in adata.obs.loc[:, groupby].cat.categories:
101
103
  temp = adata.obs.loc[:, groupby] == category
102
- aggr[category] = adata[temp].obsm["X_pca"][:, :n_components].mean(axis=0)
103
-
104
+ aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
104
105
  aggr = pd.DataFrame(aggr)
105
-
106
106
  return aggr
107
107
 
108
108
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
@@ -558,7 +558,7 @@ class Dialogue:
558
558
  self,
559
559
  adata: AnnData,
560
560
  ct_order: list[str],
561
- agg_pca: bool = True,
561
+ agg_feature: bool = True,
562
562
  normalize: bool = True,
563
563
  ) -> tuple[list, dict]:
564
564
  """Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
@@ -568,14 +568,14 @@ class Dialogue:
568
568
  Args:
569
569
  adata: AnnData object generate celltype objects for
570
570
  ct_order: The order of cell types
571
- agg_pca: Whether to aggregate pseudobulks with PCA or not.
571
+ agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
572
572
  normalize: Whether to mimic DIALOGUE behavior or not.
573
573
 
574
574
  Returns:
575
575
  A celltype_label:array dictionary.
576
576
  """
577
577
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
578
- fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
578
+ fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
579
579
  ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
580
580
 
581
581
  # TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
@@ -593,7 +593,7 @@ class Dialogue:
593
593
  adata: AnnData,
594
594
  penalties: list[int] = None,
595
595
  ct_order: list[str] = None,
596
- agg_pca: bool = True,
596
+ agg_feature: bool = True,
597
597
  solver: Literal["lp", "bs"] = "bs",
598
598
  normalize: bool = True,
599
599
  ) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
@@ -606,7 +606,7 @@ class Dialogue:
606
606
  sample_id: Key to use for pseudobulk determination.
607
607
  penalties: PMD penalties.
608
608
  ct_order: The order of cell types.
609
- agg_pca: Whether to calculate cell-averaged PCA components.
609
+ agg_features: Whether to calculate cell-averaged principal components.
610
610
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
611
611
  For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
612
612
  normalize: Whether to mimic DIALOGUE as close as possible
@@ -631,7 +631,7 @@ class Dialogue:
631
631
  else:
632
632
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
633
633
 
634
- mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
634
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
635
635
 
636
636
  n_samples = mcca_in[0].shape[1]
637
637
  if penalties is None:
@@ -1070,7 +1070,6 @@ class Dialogue:
1070
1070
  *,
1071
1071
  split_which: tuple[str, str] = None,
1072
1072
  mcp: str = "mcp_0",
1073
- show: bool = True,
1074
1073
  return_fig: bool = False,
1075
1074
  ) -> Figure | None:
1076
1075
  """Plots split violin plots for a given MCP and split variable.
@@ -1110,10 +1109,9 @@ class Dialogue:
1110
1109
  ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1111
1110
  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1112
1111
 
1113
- if show:
1114
- plt.show()
1115
1112
  if return_fig:
1116
1113
  return plt.gcf()
1114
+ plt.show()
1117
1115
  return None
1118
1116
 
1119
1117
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1125,7 +1123,6 @@ class Dialogue:
1125
1123
  sample_id: str,
1126
1124
  *,
1127
1125
  mcp: str = "mcp_0",
1128
- show: bool = True,
1129
1126
  return_fig: bool = False,
1130
1127
  ) -> Figure | None:
1131
1128
  """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
@@ -1167,8 +1164,7 @@ class Dialogue:
1167
1164
  mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1168
1165
  sns.pairplot(mcp_pivot, hue=color, corner=True)
1169
1166
 
1170
- if show:
1171
- plt.show()
1172
1167
  if return_fig:
1173
1168
  return plt.gcf()
1169
+ plt.show()
1174
1170
  return None
@@ -125,7 +125,6 @@ class MethodBase(ABC):
125
125
  shape_order: list[str] | None = None,
126
126
  x_label: str | None = None,
127
127
  y_label: str | None = None,
128
- show: bool = True,
129
128
  return_fig: bool = False,
130
129
  **kwargs: int,
131
130
  ) -> Figure | None:
@@ -484,10 +483,9 @@ class MethodBase(ABC):
484
483
 
485
484
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
486
485
 
487
- if show:
488
- plt.show()
489
486
  if return_fig:
490
487
  return plt.gcf()
488
+ plt.show()
491
489
  return None
492
490
 
493
491
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -511,7 +509,6 @@ class MethodBase(ABC):
511
509
  pvalue_template=lambda x: f"p={x:.2e}",
512
510
  boxplot_properties=None,
513
511
  palette=None,
514
- show: bool = True,
515
512
  return_fig: bool = False,
516
513
  ) -> Figure | None:
517
514
  """Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
@@ -679,10 +676,9 @@ class MethodBase(ABC):
679
676
  )
680
677
 
681
678
  plt.tight_layout()
682
- if show:
683
- plt.show()
684
679
  if return_fig:
685
680
  return plt.gcf()
681
+ plt.show()
686
682
  return None
687
683
 
688
684
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -696,7 +692,6 @@ class MethodBase(ABC):
696
692
  symbol_col: str = "variable",
697
693
  y_label: str = "Log2 fold change",
698
694
  figsize: tuple[int, int] = (10, 5),
699
- show: bool = True,
700
695
  return_fig: bool = False,
701
696
  **barplot_kwargs,
702
697
  ) -> Figure | None:
@@ -762,10 +757,9 @@ class MethodBase(ABC):
762
757
  plt.xlabel("")
763
758
  plt.ylabel(y_label)
764
759
 
765
- if show:
766
- plt.show()
767
760
  if return_fig:
768
761
  return plt.gcf()
762
+ plt.show()
769
763
  return None
770
764
 
771
765
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -782,7 +776,6 @@ class MethodBase(ABC):
782
776
  figsize: tuple[int, int] = (10, 2),
783
777
  x_label: str = "Contrast",
784
778
  y_label: str = "Gene",
785
- show: bool = True,
786
779
  return_fig: bool = False,
787
780
  **heatmap_kwargs,
788
781
  ) -> Figure | None:
@@ -880,10 +873,9 @@ class MethodBase(ABC):
880
873
  plt.xlabel(x_label)
881
874
  plt.ylabel(y_label)
882
875
 
883
- if show:
884
- plt.show()
885
876
  if return_fig:
886
877
  return plt.gcf()
878
+ plt.show()
887
879
  return None
888
880
 
889
881
 
@@ -1117,67 +1117,75 @@ class MeanVarDistributionDistance(AbstractDistance):
1117
1117
  super().__init__()
1118
1118
  self.accepts_precomputed = False
1119
1119
 
1120
+ @staticmethod
1121
+ def _mean_var(x, log: bool = False):
1122
+ mean = np.mean(x, axis=0)
1123
+ var = np.var(x, axis=0)
1124
+ positive = mean > 0
1125
+ mean = mean[positive]
1126
+ var = var[positive]
1127
+ if log:
1128
+ mean = np.log(mean)
1129
+ var = np.log(var)
1130
+ return mean, var
1131
+
1132
+ @staticmethod
1133
+ def _prep_kde_data(x, y):
1134
+ return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1135
+
1136
+ @staticmethod
1137
+ def _grid_points(d, n_points=100):
1138
+ # Make grid, add 1 bin on lower/upper end to get final n_points
1139
+ d_min = d.min()
1140
+ d_max = d.max()
1141
+ # Compute bin size
1142
+ d_bin = (d_max - d_min) / (n_points - 2)
1143
+ d_min = d_min - d_bin
1144
+ d_max = d_max + d_bin
1145
+ return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1146
+
1147
+ @staticmethod
1148
+ def _kde_eval_both(x_kde, y_kde, grid):
1149
+ n_points = len(grid)
1150
+ chunk_size = 10000
1151
+
1152
+ result_x = np.zeros(n_points)
1153
+ result_y = np.zeros(n_points)
1154
+
1155
+ # Process same chunks for both KDEs
1156
+ for start in range(0, n_points, chunk_size):
1157
+ end = min(start + chunk_size, n_points)
1158
+ chunk = grid[start:end]
1159
+ result_x[start:end] = x_kde.score_samples(chunk)
1160
+ result_y[start:end] = y_kde.score_samples(chunk)
1161
+
1162
+ return result_x, result_y
1163
+
1120
1164
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
1121
1165
  """Difference of mean-var distributions in 2 matrices.
1122
-
1123
1166
  Args:
1124
1167
  X: Normalized and log transformed cells x genes count matrix.
1125
1168
  Y: Normalized and log transformed cells x genes count matrix.
1126
1169
  """
1170
+ mean_x, var_x = self._mean_var(X, log=True)
1171
+ mean_y, var_y = self._mean_var(Y, log=True)
1127
1172
 
1128
- def _mean_var(x, log: bool = False):
1129
- mean = np.mean(x, axis=0)
1130
- var = np.var(x, axis=0)
1131
- positive = mean > 0
1132
- mean = mean[positive]
1133
- var = var[positive]
1134
- if log:
1135
- mean = np.log(mean)
1136
- var = np.log(var)
1137
- return mean, var
1138
-
1139
- def _prep_kde_data(x, y):
1140
- return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1141
-
1142
- def _grid_points(d, n_points=100):
1143
- # Make grid, add 1 bin on lower/upper end to get final n_points
1144
- d_min = d.min()
1145
- d_max = d.max()
1146
- # Compute bin size
1147
- d_bin = (d_max - d_min) / (n_points - 2)
1148
- d_min = d_min - d_bin
1149
- d_max = d_max + d_bin
1150
- return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1151
-
1152
- def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
1153
- # the thread_count is determined using the factor 0.875 as recommended here:
1154
- # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
1155
- with multiprocessing.Pool(thread_count) as p:
1156
- return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
1157
-
1158
- def _kde_eval(d, grid):
1159
- # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
1160
- # can not be compared well on regions further away from the data as they are -inf
1161
- kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
1162
- return _parallel_score_samples(kde, grid)
1163
-
1164
- mean_x, var_x = _mean_var(X, log=True)
1165
- mean_y, var_y = _mean_var(Y, log=True)
1166
-
1167
- x = _prep_kde_data(mean_x, var_x)
1168
- y = _prep_kde_data(mean_y, var_y)
1173
+ x = self._prep_kde_data(mean_x, var_x)
1174
+ y = self._prep_kde_data(mean_y, var_y)
1169
1175
 
1170
1176
  # Gridpoints to eval KDE on
1171
- mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
1172
- var_grid = _grid_points(np.concatenate([var_x, var_y]))
1177
+ mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
1178
+ var_grid = self._grid_points(np.concatenate([var_x, var_y]))
1173
1179
  grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
1174
1180
 
1175
- kde_x = _kde_eval(x, grid)
1176
- kde_y = _kde_eval(y, grid)
1181
+ # Fit both KDEs first
1182
+ x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
1183
+ y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
1177
1184
 
1178
- kde_diff = ((kde_x - kde_y) ** 2).mean()
1185
+ # Evaluate both KDEs on same grid chunks
1186
+ kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
1179
1187
 
1180
- return kde_diff
1188
+ return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
1181
1189
 
1182
1190
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
1183
1191
  raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
@@ -304,7 +304,6 @@ class Enrichment:
304
304
  groupby: str = None,
305
305
  key: str = "pertpy_enrichment",
306
306
  ax: Axes | None = None,
307
- show: bool = True,
308
307
  return_fig: bool = False,
309
308
  **kwargs,
310
309
  ) -> DotPlot | None:
@@ -417,10 +416,9 @@ class Enrichment:
417
416
  **kwargs,
418
417
  )
419
418
 
420
- if show:
421
- plt.show()
422
419
  if return_fig:
423
420
  return fig
421
+ plt.show()
424
422
  return None
425
423
 
426
424
  def plot_gsea(
pertpy/tools/_milo.py CHANGED
@@ -727,7 +727,6 @@ class Milo:
727
727
  color_map: Colormap | str | None = None,
728
728
  palette: str | Sequence[str] | None = None,
729
729
  ax: Axes | None = None,
730
- show: bool = True,
731
730
  return_fig: bool = False,
732
731
  **kwargs,
733
732
  ) -> Figure | None:
@@ -803,10 +802,9 @@ class Milo:
803
802
  **kwargs,
804
803
  )
805
804
 
806
- if show:
807
- plt.show()
808
805
  if return_fig:
809
806
  return fig
807
+ plt.show()
810
808
  return None
811
809
 
812
810
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -820,7 +818,6 @@ class Milo:
820
818
  color_map: Colormap | str | None = None,
821
819
  palette: str | Sequence[str] | None = None,
822
820
  ax: Axes | None = None,
823
- show: bool = True,
824
821
  return_fig: bool = False,
825
822
  **kwargs,
826
823
  ) -> Figure | None:
@@ -866,10 +863,9 @@ class Milo:
866
863
  **kwargs,
867
864
  )
868
865
 
869
- if show:
870
- plt.show()
871
866
  if return_fig:
872
867
  return fig
868
+ plt.show()
873
869
  return None
874
870
 
875
871
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -882,7 +878,6 @@ class Milo:
882
878
  alpha: float = 0.1,
883
879
  subset_nhoods: list[str] = None,
884
880
  palette: str | Sequence[str] | dict[str, str] | None = None,
885
- show: bool = True,
886
881
  return_fig: bool = False,
887
882
  ) -> Figure | None:
888
883
  """Plot beeswarm plot of logFC against nhood labels
@@ -994,10 +989,9 @@ class Milo:
994
989
  plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
995
990
  plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
996
991
 
997
- if show:
998
- plt.show()
999
992
  if return_fig:
1000
993
  return plt.gcf()
994
+ plt.show()
1001
995
  return None
1002
996
 
1003
997
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1008,7 +1002,6 @@ class Milo:
1008
1002
  *,
1009
1003
  subset_nhoods: list[str] = None,
1010
1004
  log_counts: bool = False,
1011
- show: bool = True,
1012
1005
  return_fig: bool = False,
1013
1006
  ) -> Figure | None:
1014
1007
  """Plot boxplot of cell numbers vs condition of interest.
@@ -1050,8 +1043,7 @@ class Milo:
1050
1043
  plt.xticks(rotation=90)
1051
1044
  plt.xlabel(test_var)
1052
1045
 
1053
- if show:
1054
- plt.show()
1055
1046
  if return_fig:
1056
1047
  return plt.gcf()
1048
+ plt.show()
1057
1049
  return None