pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pertpy/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = "Lukas Heumos"
4
4
  __email__ = "lukas.heumos@posteo.net"
5
- __version__ = "0.9.5"
5
+ __version__ = "0.10.0"
6
6
 
7
7
  import warnings
8
8
 
pertpy/_doc.py CHANGED
@@ -15,6 +15,5 @@ def _doc_params(**kwds): # pragma: no cover
15
15
 
16
16
 
17
17
  doc_common_plot_args = """\
18
- show: if `True`, shows the plot.
19
- return_fig: if `True`, returns figure of the plot.\
18
+ return_fig: if `True`, returns figure of the plot, that can be used for saving.\
20
19
  """
@@ -703,7 +703,6 @@ class CellLine(MetaData):
703
703
  metadata_key: str = "bulk_rna_broad",
704
704
  category: str = "cell line",
705
705
  subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
706
- show: bool = True,
707
706
  return_fig: bool = False,
708
707
  ) -> Figure | None:
709
708
  """Visualise the correlation of cell lines with annotated metadata.
@@ -747,7 +746,7 @@ class CellLine(MetaData):
747
746
  if all(isinstance(id, str) for id in subset_identifier_list):
748
747
  if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
749
748
  subset_identifier_list = np.where(
750
- np.in1d(adata.obs[identifier].values, subset_identifier_list)
749
+ np.isin(adata.obs[identifier].values, subset_identifier_list)
751
750
  )[0]
752
751
  else:
753
752
  raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
@@ -798,10 +797,9 @@ class CellLine(MetaData):
798
797
  },
799
798
  )
800
799
 
801
- if show:
802
- plt.show()
803
800
  if return_fig:
804
801
  return plt.gcf()
802
+ plt.show()
805
803
  return None
806
804
  else:
807
- raise NotImplementedError
805
+ raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
@@ -1,15 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import uuid
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Literal
5
+ from warnings import warn
5
6
 
6
7
  import matplotlib.pyplot as plt
7
8
  import numpy as np
8
9
  import pandas as pd
9
10
  import scanpy as sc
10
11
  import scipy
12
+ from rich.progress import track
13
+ from scipy.sparse import issparse
11
14
 
12
15
  from pertpy._doc import _doc_params, doc_common_plot_args
16
+ from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
13
17
 
14
18
  if TYPE_CHECKING:
15
19
  from anndata import AnnData
@@ -17,7 +21,7 @@ if TYPE_CHECKING:
17
21
 
18
22
 
19
23
  class GuideAssignment:
20
- """Offers simple guide assigment based on count thresholds."""
24
+ """Assign cells to guide RNAs."""
21
25
 
22
26
  def assign_by_threshold(
23
27
  self,
@@ -33,12 +37,12 @@ class GuideAssignment:
33
37
  This function expects unnormalized data as input.
34
38
 
35
39
  Args:
36
- adata: Annotated data matrix containing gRNA values
40
+ adata: AnnData object containing gRNA values.
37
41
  assignment_threshold: The count threshold that is required for an assignment to be viable.
38
42
  layer: Key to the layer containing raw count values of the gRNAs.
39
43
  adata.X is used if layer is None. Expects count data.
40
44
  output_layer: Assigned guide will be saved on adata.layers[output_key].
41
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
45
+ only_return_results: If `True`, the input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
42
46
 
43
47
  Examples:
44
48
  Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -67,7 +71,7 @@ class GuideAssignment:
67
71
  assignment_threshold: float,
68
72
  layer: str | None = None,
69
73
  output_key: str = "assigned_guide",
70
- no_grna_assigned_key: str = "NT",
74
+ no_grna_assigned_key: str = "Negative",
71
75
  only_return_results: bool = False,
72
76
  ) -> np.ndarray | None:
73
77
  """Simple threshold based max gRNA assignment function.
@@ -76,13 +80,13 @@ class GuideAssignment:
76
80
  This function expects unnormalized data as input.
77
81
 
78
82
  Args:
79
- adata: Annotated data matrix containing gRNA values
83
+ adata: AnnData object containing gRNA values.
80
84
  assignment_threshold: The count threshold that is required for an assignment to be viable.
81
85
  layer: Key to the layer containing raw count values of the gRNAs.
82
86
  adata.X is used if layer is None. Expects count data.
83
87
  output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
84
88
  no_grna_assigned_key: The key to return if no gRNA is expressed enough.
85
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
89
+ only_return_results: If `True`, the input AnnData is not modified and the result is returned as an np.ndarray.
86
90
 
87
91
  Examples:
88
92
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -109,6 +113,92 @@ class GuideAssignment:
109
113
 
110
114
  return None
111
115
 
116
def assign_mixture_model(
    self,
    adata: AnnData,
    model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
    assigned_guides_key: str = "assigned_guide",
    no_grna_assigned_key: str = "negative",
    max_assignments_per_cell: int = 5,
    multiple_grna_assigned_key: str = "multiple",
    multiple_grna_assignment_string: str = "+",
    only_return_results: bool = False,
    uns_key: str = "guide_assignment_params",
    show_progress: bool = False,
    **mixture_model_kwargs,
) -> np.ndarray | None:
    """Assigns gRNAs to cells using a mixture model.

    For every gRNA, the mixture model is fitted to the log2-transformed non-zero values of that gRNA.
    Cells with a zero value for a gRNA are always negative for it. gRNAs that are non-zero in fewer
    than two cells are skipped with a warning.

    Args:
        adata: AnnData object containing gRNA values in `adata.X`. Values must be non-negative.
        model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
        assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
        no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
        max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
            Cells positive for more gRNAs are labelled with `multiple_grna_assigned_key`.
        multiple_grna_assigned_key: The key to return if more than `max_assignments_per_cell` gRNAs
            are assigned to a cell.
        multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
        only_return_results: If `True`, the input AnnData is not modified (neither `.obs` nor `.uns`)
            and the result is returned as an np.ndarray.
        uns_key: Key in adata.uns under which the fitted model parameters of every gRNA are stored.
        show_progress: Whether to show a progress bar.
        mixture_model_kwargs: Are passed to the mixture model.

    Returns:
        If `only_return_results` is `True`, an array with one assignment label per cell, otherwise `None`.

    Raises:
        ValueError: If `model` is not supported, if adata.uns[uns_key] exists but is not a dictionary,
            or if the data contains negative values.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_mixture_model(gdo)
    """
    if model == "poisson_gauss_mixture":
        mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
    else:
        raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

    # Validate before the (slow) per-gRNA fitting so a misconfigured entry fails fast.
    # Only relevant when results are written back to the AnnData.
    if not only_return_results and uns_key in adata.uns and not isinstance(adata.uns[uns_key], dict):
        raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")

    res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
    # Collected locally so that `only_return_results=True` leaves adata.uns untouched.
    fitted_params = {}
    genes = track(adata.var_names) if show_progress else adata.var_names
    for gene in genes:
        # Densify this gRNA's column once; `.toarray()` works for scipy sparse matrices and sparse arrays
        # alike (unlike `.todense().A1`, which only exists on np.matrix).
        counts = adata[:, gene].X
        counts = np.ravel(counts.toarray() if issparse(counts) else np.asarray(counts))
        is_nonzero = counts != 0
        if is_nonzero.sum() < 2:
            warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
            continue

        # We are only fitting the model to the non-zero values, the rest is
        # automatically assigned to the negative class
        data = counts[is_nonzero]
        if np.any(data < 0):
            raise ValueError(
                "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
            )

        # Log2 transform the data so positive population is approximately normal
        data = np.log2(data)
        assignments = mixture_model.run_model(data)
        res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
        fitted_params[gene] = mixture_model.params

    # Label every cell: no positive gRNA -> `no_grna_assigned_key`; up to `max_assignments_per_cell`
    # positive gRNAs -> their names joined by `multiple_grna_assignment_string`; more than that ->
    # `multiple_grna_assigned_key`.
    num_guides_assigned = res.sum(axis=1)
    joined_guides = res.apply(lambda row: multiple_grna_assignment_string.join(row.index[row == 1]), axis=1)
    series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
    within_limit = (num_guides_assigned != 0) & (num_guides_assigned <= max_assignments_per_cell)
    series.loc[within_limit] = joined_guides[within_limit]
    series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key

    if only_return_results:
        return series.values

    if uns_key not in adata.uns:
        adata.uns[uns_key] = {}
    adata.uns[uns_key].update(fitted_params)
    adata.obs[assigned_guides_key] = series.values

    return None
201
+
112
202
  @_doc_params(common_plot_args=doc_common_plot_args)
113
203
  def plot_heatmap(
114
204
  self,
@@ -117,7 +207,6 @@ class GuideAssignment:
117
207
  layer: str | None = None,
118
208
  order_by: np.ndarray | str | None = None,
119
209
  key_to_save_order: str = None,
120
- show: bool = True,
121
210
  return_fig: bool = False,
122
211
  **kwargs,
123
212
  ) -> Figure | None:
@@ -194,8 +283,7 @@ class GuideAssignment:
194
283
  finally:
195
284
  del adata.obs[temp_col_name]
196
285
 
197
- if show:
198
- plt.show()
199
286
  if return_fig:
200
287
  return fig
288
+ plt.show()
201
289
  return None
@@ -0,0 +1,179 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Mapping
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import numpyro
10
+ import numpyro.distributions as dist
11
+ from jax import random
12
+ from numpyro.infer import MCMC, NUTS
13
+
14
# Mapping from model parameter name (e.g. "poisson_rate") to its sampled or posterior-averaged value.
ParamsDict = Mapping[str, jnp.ndarray]


class MixtureModel(ABC):
    """Abstract base class for 2-component mixture models fitted with MCMC (NUTS).

    Subclasses provide the priors (:meth:`initialize_params`) and the per-component log likelihoods
    (:meth:`log_likelihood`). After fitting, component 0 is reported as "Negative" and component 1 as
    "Positive".

    Args:
        num_warmup: Number of warmup steps for MCMC sampling.
        num_samples: Number of samples to draw after warmup.
        fraction_positive_expected: Prior belief about fraction of positive components.
        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
    """

    def __init__(
        self,
        num_warmup: int = 50,
        num_samples: int = 100,
        fraction_positive_expected: float = 0.15,
        poisson_rate_prior: float = 0.2,
        gaussian_mean_prior: tuple[float, float] = (3, 2),
        gaussian_std_prior: float = 1,
    ) -> None:
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.fraction_positive_expected = fraction_positive_expected
        self.poisson_rate_prior = poisson_rate_prior
        self.gaussian_mean_prior = gaussian_mean_prior
        self.gaussian_std_prior = gaussian_std_prior

    @abstractmethod
    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via sampling from priors.

        Returns:
            Dictionary of sampled parameter values.
        """
        pass

    @abstractmethod
    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate log likelihood of data under current parameters.

        Args:
            data: Input data array.
            params: Current parameter values.

        Returns:
            Log likelihood values for each datapoint, with the component index on the last axis.
        """
        pass

    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
        """Fit the mixture model using MCMC.

        Args:
            data: Input data to fit.
            seed: Random seed for reproducibility.

        Returns:
            Fitted MCMC object containing samples.
        """
        nuts_kernel = NUTS(self.mixture_model)
        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
        mcmc.run(random.PRNGKey(seed), data=data)
        return mcmc

    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
        """Run model fitting and assign components.

        Args:
            data: Input data array.
            seed: Random seed.

        Returns:
            Array of "Positive"/"Negative" assignments for each datapoint.
        """
        self.mcmc = self.fit_model(data, seed)
        self.samples = self.mcmc.get_samples()
        self.assignments = self.assignment(self.samples, data)
        return self.assignments

    def mixture_model(self, data: jnp.ndarray) -> None:
        """Define mixture model structure for NumPyro.

        The per-datapoint marginal mixture log likelihood is added to the joint log density with
        :func:`numpyro.factor`.

        Args:
            data: Input data array.
        """
        params = self.initialize_params()

        with numpyro.plate("data", data.shape[0]):
            # Per-component log likelihoods (expected to already include the log mixture weights).
            log_likelihoods = self.log_likelihood(data, params)
            # logsumexp over the component axis marginalizes out the latent component assignment.
            log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
            # The marginal log likelihood is itself the observation term. Scoring `data` under
            # Normal(log_mixture_likelihood, 1) would instead compare raw data values to log-probabilities.
            numpyro.factor("obs", log_mixture_likelihood)

    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
        """Assign data points to mixture components.

        Parameters are summarized by their posterior mean (stored on ``self.params``), and each data
        point is assigned to the component with the highest log likelihood.

        Args:
            samples: MCMC samples of parameters.
            data: Input data array.

        Returns:
            Array of component assignments ("Negative" for component 0, "Positive" otherwise).
        """
        params = {key: value.mean(axis=0) for key, value in samples.items()}
        self.params = params

        log_likelihoods = self.log_likelihood(data, params)
        best_component = np.asarray(jnp.argmax(log_likelihoods, axis=-1))

        # Vectorized labelling instead of a Python loop over jax scalars.
        return np.where(best_component == 0, "Negative", "Positive")
128
+
129
+
130
class PoissonGaussMixture(MixtureModel):
    """Mixture model combining Poisson and Gaussian distributions.

    Component 0 is a Poisson component and component 1 a Gaussian component; the base class reports
    component 0 as "Negative" and component 1 as "Positive".
    """

    def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Compute the weighted log likelihood of every datapoint under each component.

        Args:
            data: Input data array.
            params: Current parameter values (``poisson_rate``, ``gaussian_mean``, ``gaussian_std``,
                ``mix_probs``).

        Returns:
            Array whose last axis holds ``[Poisson, Gaussian]`` log likelihoods, each including the
            log of its mixture weight.
        """
        rate, mean, std, weights = (
            params["poisson_rate"],
            params["gaussian_mean"],
            params["gaussian_std"],
            params["mix_probs"],
        )

        # Heuristic guard against the two components swapping roles: the Heaviside step adds +10 to the
        # log density whenever the Gaussian mean lies to the right of the Poisson rate. Relative to the
        # flipped configuration this acts as a penalty on a Poisson component sitting to the right.
        # Note: the step function is flat, so it contributes no gradient signal.
        numpyro.factor("separation_penalty", 10 * jnp.heaviside(mean - rate, 0))

        poisson_part = jnp.log(weights[0]) + dist.Poisson(rate).log_prob(data)
        gaussian_part = jnp.log(weights[1]) + dist.Normal(mean, std).log_prob(data)
        return jnp.stack([poisson_part, gaussian_part], axis=-1)

    def initialize_params(self) -> ParamsDict:
        """Register the model's priors with NumPyro and return the sampled values.

        Returns:
            Dictionary with keys ``poisson_rate``, ``gaussian_mean``, ``gaussian_std`` and ``mix_probs``.
        """
        # Dirichlet concentration whose mean equals [1 - fraction_positive, fraction_positive].
        concentration = jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])
        # Dict literals evaluate left to right, so sites are registered in the same order as before.
        return {
            "poisson_rate": numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior)),
            "gaussian_mean": numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior)),
            "gaussian_std": numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior)),
            "mix_probs": numpyro.sample("mix_probs", dist.Dirichlet(concentration)),
        }
pertpy/tools/_augur.py CHANGED
@@ -685,7 +685,7 @@ class Augur:
685
685
  span: float = 0.75,
686
686
  filter_negative_residuals: bool = False,
687
687
  n_threads: int = 4,
688
- augur_mode: Literal["permute"] | Literal["default"] | Literal["velocity"] = "default",
688
+ augur_mode: Literal["default", "permute", "velocity"] = "default",
689
689
  select_variance_features: bool = True,
690
690
  key_added: str = "augurpy_results",
691
691
  random_state: int | None = None,
@@ -908,41 +908,39 @@ class Augur:
908
908
  .mean()
909
909
  )
910
910
 
911
- sampled_permuted_cv_augur1 = []
912
- sampled_permuted_cv_augur2 = []
911
+ rng = np.random.default_rng()
912
+ sampled_data = []
913
913
 
914
914
  # draw mean aucs for permute1 and permute2
915
915
  for celltype in permuted_cv_augur1["cell_type"].unique():
916
916
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
917
917
  df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
918
- for permutation_idx in range(n_permutations):
919
- # subsample
920
- sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
921
- sampled_permuted_cv_augur1.append(
922
- pd.DataFrame(
923
- {
924
- "cell_type": [celltype],
925
- "permutation_idx": [permutation_idx],
926
- "mean": [sample1["augur_score"].mean(axis=0)],
927
- "std": [sample1["augur_score"].std(axis=0)],
928
- }
929
- )
930
- )
931
918
 
932
- sample2 = df2.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
933
- sampled_permuted_cv_augur2.append(
934
- pd.DataFrame(
935
- {
936
- "cell_type": [celltype],
937
- "permutation_idx": [permutation_idx],
938
- "mean": [sample2["augur_score"].mean(axis=0)],
939
- "std": [sample2["augur_score"].std(axis=0)],
940
- }
941
- )
919
+ indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
920
+ indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
921
+
922
+ scores1 = df1["augur_score"].values[indices1]
923
+ scores2 = df2["augur_score"].values[indices2]
924
+
925
+ means1 = scores1.mean(axis=1)
926
+ means2 = scores2.mean(axis=1)
927
+ stds1 = scores1.std(axis=1)
928
+ stds2 = scores2.std(axis=1)
929
+
930
+ sampled_data.append(
931
+ pd.DataFrame(
932
+ {
933
+ "cell_type": np.repeat(celltype, n_permutations),
934
+ "permutation_idx": np.arange(n_permutations),
935
+ "mean1": means1,
936
+ "mean2": means2,
937
+ "std1": stds1,
938
+ "std2": stds2,
939
+ }
942
940
  )
941
+ )
943
942
 
944
- permuted_samples1 = pd.concat(sampled_permuted_cv_augur1)
945
- permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
943
+ sampled_df = pd.concat(sampled_data)
946
944
 
947
945
  # delta between augur scores
948
946
  delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
@@ -950,9 +948,7 @@ class Augur:
950
948
  )
951
949
 
952
950
  # delta between permutation scores
953
- delta_rnd = permuted_samples1.merge(
954
- permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
955
- ).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
951
+ delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
956
952
 
957
953
  # number of values where permutations are larger than test statistic
958
954
  delta["b"] = (
@@ -967,7 +963,7 @@ class Augur:
967
963
  delta["z"] = (
968
964
  delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
969
965
  ) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
970
- # calculate pvalues
966
+
971
967
  delta["pval"] = np.minimum(
972
968
  2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
973
969
  )
@@ -982,7 +978,6 @@ class Augur:
982
978
  *,
983
979
  top_n: int = None,
984
980
  ax: Axes = None,
985
- show: bool = True,
986
981
  return_fig: bool = False,
987
982
  ) -> Figure | None:
988
983
  """Plot scatterplot of differential prioritization.
@@ -1041,10 +1036,9 @@ class Augur:
1041
1036
  legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
1042
1037
  ax.add_artist(legend1)
1043
1038
 
1044
- if show:
1045
- plt.show()
1046
1039
  if return_fig:
1047
1040
  return plt.gcf()
1041
+ plt.show()
1048
1042
  return None
1049
1043
 
1050
1044
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1055,7 +1049,6 @@ class Augur:
1055
1049
  key: str = "augurpy_results",
1056
1050
  top_n: int = 10,
1057
1051
  ax: Axes = None,
1058
- show: bool = True,
1059
1052
  return_fig: bool = False,
1060
1053
  ) -> Figure | None:
1061
1054
  """Plot a lollipop plot of the n features with largest feature importances.
@@ -1109,10 +1102,9 @@ class Augur:
1109
1102
  plt.ylabel("Gene")
1110
1103
  plt.yticks(y_axes_range, n_features["genes"])
1111
1104
 
1112
- if show:
1113
- plt.show()
1114
1105
  if return_fig:
1115
1106
  return plt.gcf()
1107
+ plt.show()
1116
1108
  return None
1117
1109
 
1118
1110
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1122,7 +1114,6 @@ class Augur:
1122
1114
  *,
1123
1115
  key: str = "augurpy_results",
1124
1116
  ax: Axes = None,
1125
- show: bool = True,
1126
1117
  return_fig: bool = False,
1127
1118
  ) -> Figure | None:
1128
1119
  """Plot a lollipop plot of the mean augur values.
@@ -1172,10 +1163,9 @@ class Augur:
1172
1163
  plt.ylabel("Cell Type")
1173
1164
  plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
1174
1165
 
1175
- if show:
1176
- plt.show()
1177
1166
  if return_fig:
1178
1167
  return plt.gcf()
1168
+ plt.show()
1179
1169
  return None
1180
1170
 
1181
1171
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1185,7 +1175,6 @@ class Augur:
1185
1175
  results2: dict[str, Any],
1186
1176
  *,
1187
1177
  top_n: int = None,
1188
- show: bool = True,
1189
1178
  return_fig: bool = False,
1190
1179
  ) -> Figure | None:
1191
1180
  """Create scatterplot with two augur results.
@@ -1243,8 +1232,7 @@ class Augur:
1243
1232
  plt.xlabel("Augur scores 1")
1244
1233
  plt.ylabel("Augur scores 2")
1245
1234
 
1246
- if show:
1247
- plt.show()
1248
1235
  if return_fig:
1249
1236
  return plt.gcf()
1237
+ plt.show()
1250
1238
  return None
pertpy/tools/_cinemaot.py CHANGED
@@ -658,7 +658,6 @@ class Cinemaot:
658
658
  title: str = "CINEMA-OT matching matrix",
659
659
  min_val: float = 0.01,
660
660
  ax: Axes | None = None,
661
- show: bool = True,
662
661
  return_fig: bool = False,
663
662
  **kwargs,
664
663
  ) -> Figure | None:
@@ -717,10 +716,9 @@ class Cinemaot:
717
716
  g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
718
717
  plt.title(title)
719
718
 
720
- if show:
721
- plt.show()
722
719
  if return_fig:
723
720
  return g
721
+ plt.show()
724
722
  return None
725
723
 
726
724