pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pertpy/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = "Lukas Heumos"
4
4
  __email__ = "lukas.heumos@posteo.net"
5
- __version__ = "0.9.5"
5
+ __version__ = "0.10.0"
6
6
 
7
7
  import warnings
8
8
 
pertpy/_doc.py CHANGED
@@ -15,6 +15,5 @@ def _doc_params(**kwds): # pragma: no cover
15
15
 
16
16
 
17
17
  doc_common_plot_args = """\
18
- show: if `True`, shows the plot.
19
- return_fig: if `True`, returns figure of the plot.\
18
+ return_fig: if `True`, returns the figure of the plot, which can be used for saving.\
20
19
  """
@@ -703,7 +703,6 @@ class CellLine(MetaData):
703
703
  metadata_key: str = "bulk_rna_broad",
704
704
  category: str = "cell line",
705
705
  subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
706
- show: bool = True,
707
706
  return_fig: bool = False,
708
707
  ) -> Figure | None:
709
708
  """Visualise the correlation of cell lines with annotated metadata.
@@ -747,7 +746,7 @@ class CellLine(MetaData):
747
746
  if all(isinstance(id, str) for id in subset_identifier_list):
748
747
  if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
749
748
  subset_identifier_list = np.where(
750
- np.in1d(adata.obs[identifier].values, subset_identifier_list)
749
+ np.isin(adata.obs[identifier].values, subset_identifier_list)
751
750
  )[0]
752
751
  else:
753
752
  raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
@@ -798,10 +797,9 @@ class CellLine(MetaData):
798
797
  },
799
798
  )
800
799
 
801
- if show:
802
- plt.show()
803
800
  if return_fig:
804
801
  return plt.gcf()
802
+ plt.show()
805
803
  return None
806
804
  else:
807
- raise NotImplementedError
805
+ raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
@@ -1,15 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import uuid
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Literal
5
+ from warnings import warn
5
6
 
6
7
  import matplotlib.pyplot as plt
7
8
  import numpy as np
8
9
  import pandas as pd
9
10
  import scanpy as sc
10
11
  import scipy
12
+ from rich.progress import track
13
+ from scipy.sparse import issparse
11
14
 
12
15
  from pertpy._doc import _doc_params, doc_common_plot_args
16
+ from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
13
17
 
14
18
  if TYPE_CHECKING:
15
19
  from anndata import AnnData
@@ -17,7 +21,7 @@ if TYPE_CHECKING:
17
21
 
18
22
 
19
23
  class GuideAssignment:
20
- """Offers simple guide assigment based on count thresholds."""
24
+ """Assign cells to guide RNAs."""
21
25
 
22
26
  def assign_by_threshold(
23
27
  self,
@@ -33,12 +37,12 @@ class GuideAssignment:
33
37
  This function expects unnormalized data as input.
34
38
 
35
39
  Args:
36
- adata: Annotated data matrix containing gRNA values
40
+ adata: AnnData object containing gRNA values.
37
41
  assignment_threshold: The count threshold that is required for an assignment to be viable.
38
42
  layer: Key to the layer containing raw count values of the gRNAs.
39
43
  adata.X is used if layer is None. Expects count data.
40
44
  output_layer: Assigned guide will be saved on adata.layers[output_layer].
41
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
45
+ only_return_results: Whether to leave the input AnnData unmodified and return the result as an :class:`np.ndarray`.
42
46
 
43
47
  Examples:
44
48
  Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -67,7 +71,7 @@ class GuideAssignment:
67
71
  assignment_threshold: float,
68
72
  layer: str | None = None,
69
73
  output_key: str = "assigned_guide",
70
- no_grna_assigned_key: str = "NT",
74
+ no_grna_assigned_key: str = "Negative",
71
75
  only_return_results: bool = False,
72
76
  ) -> np.ndarray | None:
73
77
  """Simple threshold based max gRNA assignment function.
@@ -76,13 +80,13 @@ class GuideAssignment:
76
80
  This function expects unnormalized data as input.
77
81
 
78
82
  Args:
79
- adata: Annotated data matrix containing gRNA values
83
+ adata: AnnData object containing gRNA values.
80
84
  assignment_threshold: The count threshold that is required for an assignment to be viable.
81
85
  layer: Key to the layer containing raw count values of the gRNAs.
82
86
  adata.X is used if layer is None. Expects count data.
83
87
  output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
84
88
  no_grna_assigned_key: The key to return if no gRNA is expressed enough.
85
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
89
+ only_return_results: Whether to leave the input AnnData unmodified and return the result as an np.ndarray.
86
90
 
87
91
  Examples:
88
92
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -109,6 +113,92 @@ class GuideAssignment:
109
113
 
110
114
  return None
111
115
 
116
+ def assign_mixture_model(
117
+ self,
118
+ adata: AnnData,
119
+ model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
120
+ assigned_guides_key: str = "assigned_guide",
121
+ no_grna_assigned_key: str = "negative",
122
+ max_assignments_per_cell: int = 5,
123
+ multiple_grna_assigned_key: str = "multiple",
124
+ multiple_grna_assignment_string: str = "+",
125
+ only_return_results: bool = False,
126
+ uns_key: str = "guide_assignment_params",
127
+ show_progress: bool = False,
128
+ **mixture_model_kwargs,
129
+ ) -> np.ndarray | None:
130
+ """Assigns gRNAs to cells using a mixture model.
131
+
132
+ Args:
133
+ adata: AnnData object containing gRNA values.
134
+ model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
135
+ assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
136
+ no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
137
+ max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
138
+ multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
139
+ multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
140
+ only_return_results: Whether to leave the input AnnData unmodified and return the result as an np.ndarray.
141
+ show_progress: Whether to show a progress bar.
142
+ mixture_model_kwargs: Are passed to the mixture model.
143
+
144
+ Examples:
145
+ >>> import pertpy as pt
146
+ >>> mdata = pt.dt.papalexi_2021()
147
+ >>> gdo = mdata.mod["gdo"]
148
+ >>> ga = pt.pp.GuideAssignment()
149
+ >>> ga.assign_mixture_model(gdo)
150
+ """
151
+ if model == "poisson_gauss_mixture":
152
+ mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
153
+ else:
154
+ raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
155
+
156
+ if uns_key not in adata.uns:
157
+ adata.uns[uns_key] = {}
158
+ elif type(adata.uns[uns_key]) is not dict:
159
+ raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")
160
+
161
+ res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
162
+ fct = track if show_progress else lambda iterable: iterable
163
+ for gene in fct(adata.var_names):
164
+ is_nonzero = (
165
+ np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
166
+ )
167
+ if sum(is_nonzero) < 2:
168
+ warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
169
+ continue
170
+ # We are only fitting the model to the non-zero values, the rest is
171
+ # automatically assigned to the negative class
172
+ data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
173
+ data = np.ravel(data)
174
+
175
+ if np.any(data < 0):
176
+ raise ValueError(
177
+ "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
178
+ )
179
+
180
+ # Log2 transform the data so positive population is approximately normal
181
+ data = np.log2(data)
182
+ assignments = mixture_model.run_model(data)
183
+ res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
184
+ adata.uns[uns_key][gene] = mixture_model.params
185
+
186
+ # Assign guides to cells
187
+ # Some cells might have multiple guides assigned
188
+ series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
189
+ num_guides_assigned = res.sum(1)
190
+ series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
191
+ lambda row: row.index[row == 1].tolist(), axis=1
192
+ ).str.join(multiple_grna_assignment_string)
193
+ series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
194
+
195
+ if only_return_results:
196
+ return series.values
197
+
198
+ adata.obs[assigned_guides_key] = series.values
199
+
200
+ return None
201
+
112
202
  @_doc_params(common_plot_args=doc_common_plot_args)
113
203
  def plot_heatmap(
114
204
  self,
@@ -117,7 +207,6 @@ class GuideAssignment:
117
207
  layer: str | None = None,
118
208
  order_by: np.ndarray | str | None = None,
119
209
  key_to_save_order: str = None,
120
- show: bool = True,
121
210
  return_fig: bool = False,
122
211
  **kwargs,
123
212
  ) -> Figure | None:
@@ -194,8 +283,7 @@ class GuideAssignment:
194
283
  finally:
195
284
  del adata.obs[temp_col_name]
196
285
 
197
- if show:
198
- plt.show()
199
286
  if return_fig:
200
287
  return fig
288
+ plt.show()
201
289
  return None
@@ -0,0 +1,179 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Mapping
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import numpyro
10
+ import numpyro.distributions as dist
11
+ from jax import random
12
+ from numpyro.infer import MCMC, NUTS
13
+
14
+ ParamsDict = Mapping[str, jnp.ndarray]
15
+
16
+
17
+ class MixtureModel(ABC):
18
+ """Abstract base class for 2-component mixture models.
19
+
20
+ Args:
21
+ num_warmup: Number of warmup steps for MCMC sampling.
22
+ num_samples: Number of samples to draw after warmup.
23
+ fraction_positive_expected: Prior belief about fraction of positive components.
24
+ poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
25
+ gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
26
+ gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ num_warmup: int = 50,
32
+ num_samples: int = 100,
33
+ fraction_positive_expected: float = 0.15,
34
+ poisson_rate_prior: float = 0.2,
35
+ gaussian_mean_prior: tuple[float, float] = (3, 2),
36
+ gaussian_std_prior: float = 1,
37
+ ) -> None:
38
+ self.num_warmup = num_warmup
39
+ self.num_samples = num_samples
40
+ self.fraction_positive_expected = fraction_positive_expected
41
+ self.poisson_rate_prior = poisson_rate_prior
42
+ self.gaussian_mean_prior = gaussian_mean_prior
43
+ self.gaussian_std_prior = gaussian_std_prior
44
+
45
+ @abstractmethod
46
+ def initialize_params(self) -> ParamsDict:
47
+ """Initialize model parameters via sampling from priors.
48
+
49
+ Returns:
50
+ Dictionary of sampled parameter values.
51
+ """
52
+ pass
53
+
54
+ @abstractmethod
55
+ def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
56
+ """Calculate log likelihood of data under current parameters.
57
+
58
+ Args:
59
+ data: Input data array.
60
+ params: Current parameter values.
61
+
62
+ Returns:
63
+ Log likelihood values for each datapoint.
64
+ """
65
+ pass
66
+
67
+ def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
68
+ """Fit the mixture model using MCMC.
69
+
70
+ Args:
71
+ data: Input data to fit.
72
+ seed: Random seed for reproducibility.
73
+
74
+ Returns:
75
+ Fitted MCMC object containing samples.
76
+ """
77
+ nuts_kernel = NUTS(self.mixture_model)
78
+ mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
79
+ mcmc.run(random.PRNGKey(seed), data=data)
80
+ return mcmc
81
+
82
+ def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
83
+ """Run model fitting and assign components.
84
+
85
+ Args:
86
+ data: Input data array.
87
+ seed: Random seed.
88
+
89
+ Returns:
90
+ Array of "Positive"/"Negative" assignments for each datapoint.
91
+ """
92
+ self.mcmc = self.fit_model(data, seed)
93
+ self.samples = self.mcmc.get_samples()
94
+ self.assignments = self.assignment(self.samples, data)
95
+ return self.assignments
96
+
97
+ def mixture_model(self, data: jnp.ndarray) -> None:
98
+ """Define mixture model structure for NumPyro.
99
+
100
+ Args:
101
+ data: Input data array.
102
+ """
103
+ params = self.initialize_params()
104
+
105
+ with numpyro.plate("data", data.shape[0]):
106
+ log_likelihoods = self.log_likelihood(data, params)
107
+ log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
108
+ numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)
109
+
110
+ def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
111
+ """Assign data points to mixture components.
112
+
113
+ Args:
114
+ samples: MCMC samples of parameters.
115
+ data: Input data array.
116
+
117
+ Returns:
118
+ Array of component assignments.
119
+ """
120
+ params = {key: samples[key].mean(axis=0) for key in samples.keys()}
121
+ self.params = params
122
+
123
+ log_likelihoods = self.log_likelihood(data, params)
124
+ guide_assignments = jnp.argmax(log_likelihoods, axis=-1)
125
+
126
+ assignments = ["Negative" if assign == 0 else "Positive" for assign in guide_assignments]
127
+ return np.array(assignments)
128
+
129
+
130
+ class PoissonGaussMixture(MixtureModel):
131
+ """Mixture model combining Poisson and Gaussian distributions."""
132
+
133
+ def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
134
+ """Calculate component-wise log likelihoods.
135
+
136
+ Args:
137
+ data: Input data array.
138
+ params: Current parameter values.
139
+
140
+ Returns:
141
+ Log likelihood values for each component.
142
+ """
143
+ poisson_rate = params["poisson_rate"]
144
+ gaussian_mean = params["gaussian_mean"]
145
+ gaussian_std = params["gaussian_std"]
146
+ mix_probs = params["mix_probs"]
147
+
148
+ # We penalize the model for positioning the Poisson component to the right of the Gaussian component
149
+ # by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
150
+ # Heuristic regularization term to prevent flipping of the components
151
+ numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
152
+
153
+ log_likelihoods = jnp.stack(
154
+ [
155
+ # Poisson component
156
+ jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
157
+ # Gaussian component
158
+ jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
159
+ ],
160
+ axis=-1,
161
+ )
162
+
163
+ return log_likelihoods
164
+
165
+ def initialize_params(self) -> ParamsDict:
166
+ """Initialize model parameters via prior sampling.
167
+
168
+ Returns:
169
+ Dictionary of sampled parameter values.
170
+ """
171
+ params = {}
172
+ params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
173
+ params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
174
+ params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
175
+ params["mix_probs"] = numpyro.sample(
176
+ "mix_probs",
177
+ dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
178
+ )
179
+ return params
pertpy/tools/_augur.py CHANGED
@@ -685,7 +685,7 @@ class Augur:
685
685
  span: float = 0.75,
686
686
  filter_negative_residuals: bool = False,
687
687
  n_threads: int = 4,
688
- augur_mode: Literal["permute"] | Literal["default"] | Literal["velocity"] = "default",
688
+ augur_mode: Literal["default", "permute", "velocity"] = "default",
689
689
  select_variance_features: bool = True,
690
690
  key_added: str = "augurpy_results",
691
691
  random_state: int | None = None,
@@ -908,41 +908,39 @@ class Augur:
908
908
  .mean()
909
909
  )
910
910
 
911
- sampled_permuted_cv_augur1 = []
912
- sampled_permuted_cv_augur2 = []
911
+ rng = np.random.default_rng()
912
+ sampled_data = []
913
913
 
914
914
  # draw mean aucs for permute1 and permute2
915
915
  for celltype in permuted_cv_augur1["cell_type"].unique():
916
916
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
917
917
  df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
918
- for permutation_idx in range(n_permutations):
919
- # subsample
920
- sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
921
- sampled_permuted_cv_augur1.append(
922
- pd.DataFrame(
923
- {
924
- "cell_type": [celltype],
925
- "permutation_idx": [permutation_idx],
926
- "mean": [sample1["augur_score"].mean(axis=0)],
927
- "std": [sample1["augur_score"].std(axis=0)],
928
- }
929
- )
930
- )
931
918
 
932
- sample2 = df2.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
933
- sampled_permuted_cv_augur2.append(
934
- pd.DataFrame(
935
- {
936
- "cell_type": [celltype],
937
- "permutation_idx": [permutation_idx],
938
- "mean": [sample2["augur_score"].mean(axis=0)],
939
- "std": [sample2["augur_score"].std(axis=0)],
940
- }
941
- )
919
+ indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
920
+ indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
921
+
922
+ scores1 = df1["augur_score"].values[indices1]
923
+ scores2 = df2["augur_score"].values[indices2]
924
+
925
+ means1 = scores1.mean(axis=1)
926
+ means2 = scores2.mean(axis=1)
927
+ stds1 = scores1.std(axis=1)
928
+ stds2 = scores2.std(axis=1)
929
+
930
+ sampled_data.append(
931
+ pd.DataFrame(
932
+ {
933
+ "cell_type": np.repeat(celltype, n_permutations),
934
+ "permutation_idx": np.arange(n_permutations),
935
+ "mean1": means1,
936
+ "mean2": means2,
937
+ "std1": stds1,
938
+ "std2": stds2,
939
+ }
942
940
  )
941
+ )
943
942
 
944
- permuted_samples1 = pd.concat(sampled_permuted_cv_augur1)
945
- permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
943
+ sampled_df = pd.concat(sampled_data)
946
944
 
947
945
  # delta between augur scores
948
946
  delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
@@ -950,9 +948,7 @@ class Augur:
950
948
  )
951
949
 
952
950
  # delta between permutation scores
953
- delta_rnd = permuted_samples1.merge(
954
- permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
955
- ).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
951
+ delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
956
952
 
957
953
  # number of values where permutations are larger than test statistic
958
954
  delta["b"] = (
@@ -967,7 +963,7 @@ class Augur:
967
963
  delta["z"] = (
968
964
  delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
969
965
  ) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
970
- # calculate pvalues
966
+
971
967
  delta["pval"] = np.minimum(
972
968
  2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
973
969
  )
@@ -982,7 +978,6 @@ class Augur:
982
978
  *,
983
979
  top_n: int = None,
984
980
  ax: Axes = None,
985
- show: bool = True,
986
981
  return_fig: bool = False,
987
982
  ) -> Figure | None:
988
983
  """Plot scatterplot of differential prioritization.
@@ -1041,10 +1036,9 @@ class Augur:
1041
1036
  legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
1042
1037
  ax.add_artist(legend1)
1043
1038
 
1044
- if show:
1045
- plt.show()
1046
1039
  if return_fig:
1047
1040
  return plt.gcf()
1041
+ plt.show()
1048
1042
  return None
1049
1043
 
1050
1044
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1055,7 +1049,6 @@ class Augur:
1055
1049
  key: str = "augurpy_results",
1056
1050
  top_n: int = 10,
1057
1051
  ax: Axes = None,
1058
- show: bool = True,
1059
1052
  return_fig: bool = False,
1060
1053
  ) -> Figure | None:
1061
1054
  """Plot a lollipop plot of the n features with largest feature importances.
@@ -1109,10 +1102,9 @@ class Augur:
1109
1102
  plt.ylabel("Gene")
1110
1103
  plt.yticks(y_axes_range, n_features["genes"])
1111
1104
 
1112
- if show:
1113
- plt.show()
1114
1105
  if return_fig:
1115
1106
  return plt.gcf()
1107
+ plt.show()
1116
1108
  return None
1117
1109
 
1118
1110
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1122,7 +1114,6 @@ class Augur:
1122
1114
  *,
1123
1115
  key: str = "augurpy_results",
1124
1116
  ax: Axes = None,
1125
- show: bool = True,
1126
1117
  return_fig: bool = False,
1127
1118
  ) -> Figure | None:
1128
1119
  """Plot a lollipop plot of the mean augur values.
@@ -1172,10 +1163,9 @@ class Augur:
1172
1163
  plt.ylabel("Cell Type")
1173
1164
  plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
1174
1165
 
1175
- if show:
1176
- plt.show()
1177
1166
  if return_fig:
1178
1167
  return plt.gcf()
1168
+ plt.show()
1179
1169
  return None
1180
1170
 
1181
1171
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1185,7 +1175,6 @@ class Augur:
1185
1175
  results2: dict[str, Any],
1186
1176
  *,
1187
1177
  top_n: int = None,
1188
- show: bool = True,
1189
1178
  return_fig: bool = False,
1190
1179
  ) -> Figure | None:
1191
1180
  """Create scatterplot with two augur results.
@@ -1243,8 +1232,7 @@ class Augur:
1243
1232
  plt.xlabel("Augur scores 1")
1244
1233
  plt.ylabel("Augur scores 2")
1245
1234
 
1246
- if show:
1247
- plt.show()
1248
1235
  if return_fig:
1249
1236
  return plt.gcf()
1237
+ plt.show()
1250
1238
  return None
pertpy/tools/_cinemaot.py CHANGED
@@ -658,7 +658,6 @@ class Cinemaot:
658
658
  title: str = "CINEMA-OT matching matrix",
659
659
  min_val: float = 0.01,
660
660
  ax: Axes | None = None,
661
- show: bool = True,
662
661
  return_fig: bool = False,
663
662
  **kwargs,
664
663
  ) -> Figure | None:
@@ -717,10 +716,9 @@ class Cinemaot:
717
716
  g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
718
717
  plt.title(title)
719
718
 
720
- if show:
721
- plt.show()
722
719
  if return_fig:
723
720
  return g
721
+ plt.show()
724
722
  return None
725
723
 
726
724