pertpy 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pertpy/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = "Lukas Heumos"
4
4
  __email__ = "lukas.heumos@posteo.net"
5
- __version__ = "0.9.4"
5
+ __version__ = "0.9.5"
6
6
 
7
7
  import warnings
8
8
 
pertpy/_doc.py ADDED
@@ -0,0 +1,20 @@
1
+ from textwrap import dedent
2
+
3
+
4
+ def _doc_params(**kwds): # pragma: no cover
5
+ """\
6
+ Docstrings should start with "\" in the first line for proper formatting.
7
+ """
8
+
9
+ def dec(obj):
10
+ obj.__orig_doc__ = obj.__doc__
11
+ obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
12
+ return obj
13
+
14
+ return dec
15
+
16
+
17
+ doc_common_plot_args = """\
18
+ show: if `True`, shows the plot.
19
+ return_fig: if `True`, returns figure of the plot.\
20
+ """
pertpy/data/_datasets.py CHANGED
@@ -66,7 +66,7 @@ def sc_sim_augur() -> AnnData: # pragma: no cover
66
66
  output_file_path = settings.datasetdir / output_file_name
67
67
  if not Path(output_file_path).exists():
68
68
  _download(
69
- url="https://figshare.com/ndownloader/files/31645886",
69
+ url="https://figshare.com/ndownloader/files/49828902",
70
70
  output_file_name=output_file_name,
71
71
  output_path=settings.datasetdir,
72
72
  is_zip=False,
pertpy/metadata/_cell_line.py CHANGED
@@ -8,12 +8,15 @@ from lamin_utils import logger
8
8
  if TYPE_CHECKING:
9
9
  from collections.abc import Iterable
10
10
 
11
+ from matplotlib.pyplot import Figure
12
+
11
13
  import matplotlib.pyplot as plt
12
14
  import numpy as np
13
15
  import pandas as pd
14
16
  from scanpy import settings
15
17
  from scipy import stats
16
18
 
19
+ from pertpy._doc import _doc_params, doc_common_plot_args
17
20
  from pertpy.data._dataloader import _download
18
21
 
19
22
  from ._look_up import LookUp
@@ -338,8 +341,8 @@ class CellLine(MetaData):
338
341
  # then we can compare these keys and fetch the corresponding metadata.
339
342
  if query_id not in adata.obs.columns:
340
343
  raise ValueError(
341
- f"The specified `query_id` {query_id} can't be found in the `adata.obs`.\n"
342
- "Ensure that you are using one of the available query IDs present in the adata.obs for the annotation.\n"
344
+ f"The specified `query_id` {query_id} can't be found in the `adata.obs`. \n"
345
+ "Ensure that you are using one of the available query IDs present in the adata.obs for the annotation."
343
346
  "If the desired query ID is not available, you can fetch the cell line metadata "
344
347
  "using the `annotate()` function before calling 'annotate_bulk_rna()'. "
345
348
  "This ensures that the required query ID is included in your data, e.g. stripped_cell_line_name, DepMap ID."
@@ -356,9 +359,8 @@ class CellLine(MetaData):
356
359
  else:
357
360
  reference_id = "DepMap_ID"
358
361
  logger.warning(
359
- "To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given.\n"
360
- "Ensure that `DepMap_ID` is available in 'adata.obs'.\n"
361
- "Alternatively, use `annotate()` to annotate the cell line first "
362
+ "To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given."
363
+ "If `DepMap_ID` isn't available in 'adata.obs', use `annotate()` to annotate the cell line first."
362
364
  )
363
365
  if self.bulk_rna_broad is None:
364
366
  self._download_bulk_rna(cell_line_source="broad")
@@ -690,6 +692,7 @@ class CellLine(MetaData):
690
692
 
691
693
  return corr, pvals, new_corr, new_pvals
692
694
 
695
+ @_doc_params(common_plot_args=doc_common_plot_args)
693
696
  def plot_correlation(
694
697
  self,
695
698
  adata: AnnData,
@@ -700,7 +703,9 @@ class CellLine(MetaData):
700
703
  metadata_key: str = "bulk_rna_broad",
701
704
  category: str = "cell line",
702
705
  subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
703
- ) -> None:
706
+ show: bool = True,
707
+ return_fig: bool = False,
708
+ ) -> Figure | None:
704
709
  """Visualise the correlation of cell lines with annotated metadata.
705
710
 
706
711
  Args:
@@ -713,6 +718,8 @@ class CellLine(MetaData):
713
718
  subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
714
719
  If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
715
720
  If None, all cell lines will be plotted.
721
+ {common_plot_args}
722
+
716
723
  Returns:
717
724
  Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
718
725
  """
@@ -790,6 +797,11 @@ class CellLine(MetaData):
790
797
  "edgecolor": "black",
791
798
  },
792
799
  )
793
- plt.show()
800
+
801
+ if show:
802
+ plt.show()
803
+ if return_fig:
804
+ return plt.gcf()
805
+ return None
794
806
  else:
795
807
  raise NotImplementedError
pertpy/metadata/_compound.py CHANGED
@@ -42,7 +42,7 @@ class Compound(MetaData):
42
42
  adata = adata.copy()
43
43
 
44
44
  if query_id not in adata.obs.columns:
45
- raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n" f"Please check again. ")
45
+ raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
46
46
 
47
47
  query_dict = {}
48
48
  not_matched_identifiers = []
@@ -84,7 +84,7 @@ class Compound(MetaData):
84
84
  query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
85
85
  # Merge and remove duplicate columns
86
86
  # Column is converted to float after merging due to unmatches
87
- # Convert back to integers
87
+ # Convert back to integers afterwards
88
88
  if query_id_type == "cid":
89
89
  query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
90
90
  adata.obs = (
@@ -119,8 +119,7 @@ class Compound(MetaData):
119
119
 
120
120
  The LookUp object provides an overview of the metadata to annotate.
121
121
  Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
122
- where users can search the reference_id in the metadata and
123
- compare with the query_id in their own data.
122
+ where users can search the reference_id in the metadata and compare with the query_id in their own data.
124
123
 
125
124
  Returns:
126
125
  Returns a LookUp object specific for compound annotation.
pertpy/metadata/_metadata.py CHANGED
@@ -62,7 +62,7 @@ class MetaData:
62
62
  if verbosity > 0:
63
63
  logger.info(
64
64
  f"There are {total_identifiers} identifiers in `adata.obs`."
65
- f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation,"
65
+ f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation, "
66
66
  "leading to the presence of NA values for their respective metadata.\n"
67
67
  f"Please check again: *unmatched_identifiers[:verbosity]..."
68
68
  )
pertpy/preprocessing/_guide_rna.py CHANGED
@@ -3,14 +3,17 @@ from __future__ import annotations
3
3
  import uuid
4
4
  from typing import TYPE_CHECKING
5
5
 
6
+ import matplotlib.pyplot as plt
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  import scanpy as sc
9
10
  import scipy
10
11
 
12
+ from pertpy._doc import _doc_params, doc_common_plot_args
13
+
11
14
  if TYPE_CHECKING:
12
15
  from anndata import AnnData
13
- from matplotlib.axes import Axes
16
+ from matplotlib.pyplot import Figure
14
17
 
15
18
 
16
19
  class GuideAssignment:
@@ -106,14 +109,18 @@ class GuideAssignment:
106
109
 
107
110
  return None
108
111
 
112
+ @_doc_params(common_plot_args=doc_common_plot_args)
109
113
  def plot_heatmap(
110
114
  self,
111
115
  adata: AnnData,
116
+ *,
112
117
  layer: str | None = None,
113
118
  order_by: np.ndarray | str | None = None,
114
119
  key_to_save_order: str = None,
120
+ show: bool = True,
121
+ return_fig: bool = False,
115
122
  **kwargs,
116
- ) -> list[Axes]:
123
+ ) -> Figure | None:
117
124
  """Heatmap plotting of guide RNA expression matrix.
118
125
 
119
126
  Assuming guides have sparse expression, this function reorders cells
@@ -131,11 +138,12 @@ class GuideAssignment:
131
138
  If a string is provided, adata.obs[order_by] will be used as the order.
132
139
  If a numpy array is provided, the array will be used for ordering.
133
140
  key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
141
+ {common_plot_args}
134
142
  kwargs: Are passed to sc.pl.heatmap.
135
143
 
136
144
  Returns:
137
- List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
138
- Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
145
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
146
+ Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
139
147
 
140
148
  Examples:
141
149
  Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
@@ -172,7 +180,7 @@ class GuideAssignment:
172
180
  adata.obs[key_to_save_order] = pd.Categorical(order)
173
181
 
174
182
  try:
175
- axis_group = sc.pl.heatmap(
183
+ fig = sc.pl.heatmap(
176
184
  adata[order, :],
177
185
  var_names=adata.var.index.tolist(),
178
186
  groupby=temp_col_name,
@@ -180,9 +188,14 @@ class GuideAssignment:
180
188
  use_raw=False,
181
189
  dendrogram=False,
182
190
  layer=layer,
191
+ show=False,
183
192
  **kwargs,
184
193
  )
185
194
  finally:
186
195
  del adata.obs[temp_col_name]
187
196
 
188
- return axis_group
197
+ if show:
198
+ plt.show()
199
+ if return_fig:
200
+ return fig
201
+ return None
pertpy/tools/__init__.py CHANGED
@@ -46,7 +46,7 @@ Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
46
46
  Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
47
47
 
48
48
  DE_EXTRAS = ["formulaic", "pydeseq2"]
49
- EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
49
+ EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
50
50
  PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
51
51
  Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
52
52
  TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
pertpy/tools/_augur.py CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
15
15
  from anndata import AnnData
16
16
  from joblib import Parallel, delayed
17
17
  from lamin_utils import logger
18
- from rich import print
19
18
  from rich.progress import track
20
19
  from scipy import sparse, stats
21
20
  from sklearn.base import is_classifier, is_regressor
@@ -26,17 +25,19 @@ from sklearn.metrics import (
26
25
  explained_variance_score,
27
26
  f1_score,
28
27
  make_scorer,
29
- mean_squared_error,
30
28
  precision_score,
31
29
  r2_score,
32
30
  recall_score,
33
31
  roc_auc_score,
32
+ root_mean_squared_error,
34
33
  )
35
34
  from sklearn.model_selection import StratifiedKFold, cross_validate
36
35
  from sklearn.preprocessing import LabelEncoder
37
36
  from skmisc.loess import loess
38
37
  from statsmodels.stats.multitest import fdrcorrection
39
38
 
39
+ from pertpy._doc import _doc_params, doc_common_plot_args
40
+
40
41
  if TYPE_CHECKING:
41
42
  from matplotlib.axes import Axes
42
43
  from matplotlib.figure import Figure
@@ -439,7 +440,7 @@ class Augur:
439
440
  "augur_score": make_scorer(self.ccc_score),
440
441
  "r2": make_scorer(r2_score),
441
442
  "ccc": make_scorer(self.ccc_score),
442
- "neg_mean_squared_error": make_scorer(mean_squared_error),
443
+ "neg_mean_squared_error": make_scorer(root_mean_squared_error),
443
444
  "explained_variance": make_scorer(explained_variance_score),
444
445
  }
445
446
  )
@@ -974,24 +975,26 @@ class Augur:
974
975
 
975
976
  return delta
976
977
 
978
+ @_doc_params(common_plot_args=doc_common_plot_args)
977
979
  def plot_dp_scatter(
978
980
  self,
979
981
  results: pd.DataFrame,
982
+ *,
980
983
  top_n: int = None,
981
- return_fig: bool | None = None,
982
984
  ax: Axes = None,
983
- show: bool | None = None,
984
- save: str | bool | None = None,
985
- ) -> Axes | Figure | None:
985
+ show: bool = True,
986
+ return_fig: bool = False,
987
+ ) -> Figure | None:
986
988
  """Plot scatterplot of differential prioritization.
987
989
 
988
990
  Args:
989
991
  results: Results after running differential prioritization.
990
992
  top_n: optionally, the number of top prioritized cell types to label in the plot
991
993
  ax: optionally, axes used to draw plot
994
+ {common_plot_args}
992
995
 
993
996
  Returns:
994
- Axes of the plot.
997
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
995
998
 
996
999
  Examples:
997
1000
  >>> import pertpy as pt
@@ -1038,37 +1041,34 @@ class Augur:
1038
1041
  legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
1039
1042
  ax.add_artist(legend1)
1040
1043
 
1041
- if save:
1042
- plt.savefig(save, bbox_inches="tight")
1043
1044
  if show:
1044
1045
  plt.show()
1045
1046
  if return_fig:
1046
1047
  return plt.gcf()
1047
- if not (show or save):
1048
- return ax
1049
1048
  return None
1050
1049
 
1050
+ @_doc_params(common_plot_args=doc_common_plot_args)
1051
1051
  def plot_important_features(
1052
1052
  self,
1053
1053
  data: dict[str, Any],
1054
+ *,
1054
1055
  key: str = "augurpy_results",
1055
1056
  top_n: int = 10,
1056
- return_fig: bool | None = None,
1057
1057
  ax: Axes = None,
1058
- show: bool | None = None,
1059
- save: str | bool | None = None,
1060
- ) -> Axes | None:
1058
+ show: bool = True,
1059
+ return_fig: bool = False,
1060
+ ) -> Figure | None:
1061
1061
  """Plot a lollipop plot of the n features with largest feature importances.
1062
1062
 
1063
1063
  Args:
1064
- results: results after running `predict()` as dictionary or the AnnData object.
1064
+ data: results after running `predict()` as dictionary or the AnnData object.
1065
1065
  key: Key in the AnnData object of the results
1066
1066
  top_n: n number feature importance values to plot. Default is 10.
1067
1067
  ax: optionally, axes used to draw plot
1068
- return_figure: if `True` returns figure of the plot, default is `False`
1068
+ {common_plot_args}
1069
1069
 
1070
1070
  Returns:
1071
- Axes of the plot.
1071
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1072
1072
 
1073
1073
  Examples:
1074
1074
  >>> import pertpy as pt
@@ -1109,35 +1109,32 @@ class Augur:
1109
1109
  plt.ylabel("Gene")
1110
1110
  plt.yticks(y_axes_range, n_features["genes"])
1111
1111
 
1112
- if save:
1113
- plt.savefig(save, bbox_inches="tight")
1114
1112
  if show:
1115
1113
  plt.show()
1116
1114
  if return_fig:
1117
1115
  return plt.gcf()
1118
- if not (show or save):
1119
- return ax
1120
1116
  return None
1121
1117
 
1118
+ @_doc_params(common_plot_args=doc_common_plot_args)
1122
1119
  def plot_lollipop(
1123
1120
  self,
1124
- data: dict[str, Any],
1121
+ data: dict[str, Any] | AnnData,
1122
+ *,
1125
1123
  key: str = "augurpy_results",
1126
- return_fig: bool | None = None,
1127
1124
  ax: Axes = None,
1128
- show: bool | None = None,
1129
- save: str | bool | None = None,
1130
- ) -> Axes | Figure | None:
1125
+ show: bool = True,
1126
+ return_fig: bool = False,
1127
+ ) -> Figure | None:
1131
1128
  """Plot a lollipop plot of the mean augur values.
1132
1129
 
1133
1130
  Args:
1134
- results: results after running `predict()` as dictionary or the AnnData object.
1135
- key: Key in the AnnData object of the results
1136
- ax: optionally, axes used to draw plot
1137
- return_figure: if `True` returns figure of the plot
1131
+ data: results after running `predict()` as dictionary or the AnnData object.
1132
+ key: .uns key in the results AnnData object.
1133
+ ax: optionally, axes used to draw plot.
1134
+ {common_plot_args}
1138
1135
 
1139
1136
  Returns:
1140
- Axes of the plot.
1137
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1141
1138
 
1142
1139
  Examples:
1143
1140
  >>> import pertpy as pt
@@ -1175,32 +1172,29 @@ class Augur:
1175
1172
  plt.ylabel("Cell Type")
1176
1173
  plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
1177
1174
 
1178
- if save:
1179
- plt.savefig(save, bbox_inches="tight")
1180
1175
  if show:
1181
1176
  plt.show()
1182
1177
  if return_fig:
1183
1178
  return plt.gcf()
1184
- if not (show or save):
1185
- return ax
1186
1179
  return None
1187
1180
 
1181
+ @_doc_params(common_plot_args=doc_common_plot_args)
1188
1182
  def plot_scatterplot(
1189
1183
  self,
1190
1184
  results1: dict[str, Any],
1191
1185
  results2: dict[str, Any],
1186
+ *,
1192
1187
  top_n: int = None,
1193
- return_fig: bool | None = None,
1194
- show: bool | None = None,
1195
- save: str | bool | None = None,
1196
- ) -> Axes | Figure | None:
1188
+ show: bool = True,
1189
+ return_fig: bool = False,
1190
+ ) -> Figure | None:
1197
1191
  """Create scatterplot with two augur results.
1198
1192
 
1199
1193
  Args:
1200
1194
  results1: results after running `predict()`
1201
1195
  results2: results after running `predict()`
1202
1196
  top_n: optionally, the number of top prioritized cell types to label in the plot
1203
- return_figure: if `True` returns figure of the plot
1197
+ {common_plot_args}
1204
1198
 
1205
1199
  Returns:
1206
1200
  Axes of the plot.
@@ -1249,12 +1243,8 @@ class Augur:
1249
1243
  plt.xlabel("Augur scores 1")
1250
1244
  plt.ylabel("Augur scores 2")
1251
1245
 
1252
- if save:
1253
- plt.savefig(save, bbox_inches="tight")
1254
1246
  if show:
1255
1247
  plt.show()
1256
1248
  if return_fig:
1257
1249
  return plt.gcf()
1258
- if not (show or save):
1259
- return ax
1260
1250
  return None
pertpy/tools/_cinemaot.py CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
18
18
  from sklearn.linear_model import LinearRegression
19
19
  from sklearn.neighbors import NearestNeighbors
20
20
 
21
+ from pertpy._doc import _doc_params, doc_common_plot_args
22
+
21
23
  if TYPE_CHECKING:
22
24
  from anndata import AnnData
23
25
  from matplotlib.axes import Axes
26
+ from matplotlib.pyplot import Figure
24
27
  from statsmodels.tools.typing import ArrayLike
25
28
 
26
29
 
@@ -88,7 +91,7 @@ class Cinemaot:
88
91
  dim = self.get_dim(adata, use_rep=use_rep)
89
92
 
90
93
  transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
91
- X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
94
+ X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
92
95
  groupvec = (adata.obs[pert_key] == control * 1).values # control
93
96
  xi = np.zeros(dim)
94
97
  j = 0
@@ -97,9 +100,9 @@ class Cinemaot:
97
100
  xi[j] = xi_obj.correlation
98
101
  j = j + 1
99
102
 
100
- cf = X_transformed[:, xi < thres]
101
- cf1 = cf[adata.obs[pert_key] == control, :]
102
- cf2 = cf[adata.obs[pert_key] != control, :]
103
+ cf = np.array(X_transformed[:, xi < thres], np.float64)
104
+ cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
105
+ cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
103
106
  if sum(xi < thres) == 1:
104
107
  sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
105
108
  elif sum(xi < thres) == 0:
@@ -167,7 +170,7 @@ class Cinemaot:
167
170
  else:
168
171
  _solver = sinkhorn.Sinkhorn(threshold=eps)
169
172
  ot_sink = _solver(ot_prob)
170
- ot_matrix = ot_sink.matrix.T
173
+ ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
171
174
  embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
172
175
  ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
173
176
  )
@@ -639,6 +642,7 @@ class Cinemaot:
639
642
  s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
640
643
  return c_effect, s_effect
641
644
 
645
+ @_doc_params(common_plot_args=doc_common_plot_args)
642
646
  def plot_vis_matching(
643
647
  self,
644
648
  adata: AnnData,
@@ -647,16 +651,17 @@ class Cinemaot:
647
651
  control: str,
648
652
  de_label: str,
649
653
  source_label: str,
654
+ *,
650
655
  matching_rep: str = "ot",
651
656
  resolution: float = 0.5,
652
657
  normalize: str = "col",
653
658
  title: str = "CINEMA-OT matching matrix",
654
659
  min_val: float = 0.01,
655
- show: bool = True,
656
- save: str | None = None,
657
660
  ax: Axes | None = None,
661
+ show: bool = True,
662
+ return_fig: bool = False,
658
663
  **kwargs,
659
- ) -> None:
664
+ ) -> Figure | None:
660
665
  """Visualize the CINEMA-OT matching matrix.
661
666
 
662
667
  Args:
@@ -670,11 +675,12 @@ class Cinemaot:
670
675
  normalize: normalize the coarse-grained matching matrix by row / column.
671
676
  title: the title for the figure.
672
677
  min_val: The min value to truncate the matching matrix.
673
- show: Show the plot, do not return axis.
674
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
675
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
678
+ {common_plot_args}
676
679
  **kwargs: Other parameters to input for seaborn.heatmap.
677
680
 
681
+ Returns:
682
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
683
+
678
684
  Examples:
679
685
  >>> import pertpy as pt
680
686
  >>> adata = pt.dt.cinemaot_example()
@@ -710,12 +716,12 @@ class Cinemaot:
710
716
 
711
717
  g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
712
718
  plt.title(title)
713
- _utils.savefig_or_show("matching_heatmap", show=show, save=save)
714
- if not show:
715
- if ax is not None:
716
- return ax
717
- else:
718
- return g
719
+
720
+ if show:
721
+ plt.show()
722
+ if return_fig:
723
+ return g
724
+ return None
719
725
 
720
726
 
721
727
  class Xi: