pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
pertpy/_doc.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
from textwrap import dedent
|
2
|
+
|
3
|
+
|
4
|
+
def _doc_params(**kwds): # pragma: no cover
|
5
|
+
"""\
|
6
|
+
Docstrings should start with "\" in the first line for proper formatting.
|
7
|
+
"""
|
8
|
+
|
9
|
+
def dec(obj):
|
10
|
+
obj.__orig_doc__ = obj.__doc__
|
11
|
+
obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
|
12
|
+
return obj
|
13
|
+
|
14
|
+
return dec
|
15
|
+
|
16
|
+
|
17
|
+
doc_common_plot_args = """\
|
18
|
+
show: if `True`, shows the plot.
|
19
|
+
return_fig: if `True`, returns figure of the plot.\
|
20
|
+
"""
|
pertpy/data/_dataloader.py
CHANGED
@@ -23,10 +23,10 @@ def _download( # pragma: no cover
|
|
23
23
|
Args:
|
24
24
|
url: URL to download
|
25
25
|
output_file_name: Name of the downloaded file
|
26
|
-
output_path: Path to download/extract the files to
|
27
|
-
block_size: Block size for downloads in bytes
|
28
|
-
overwrite: Whether to overwrite existing files
|
29
|
-
is_zip: Whether the downloaded file needs to be unzipped
|
26
|
+
output_path: Path to download/extract the files to.
|
27
|
+
block_size: Block size for downloads in bytes.
|
28
|
+
overwrite: Whether to overwrite existing files.
|
29
|
+
is_zip: Whether the downloaded file needs to be unzipped.
|
30
30
|
"""
|
31
31
|
if output_file_name is None:
|
32
32
|
letters = ascii_lowercase
|
pertpy/data/_datasets.py
CHANGED
@@ -66,7 +66,7 @@ def sc_sim_augur() -> AnnData: # pragma: no cover
|
|
66
66
|
output_file_path = settings.datasetdir / output_file_name
|
67
67
|
if not Path(output_file_path).exists():
|
68
68
|
_download(
|
69
|
-
url="https://figshare.com/ndownloader/files/
|
69
|
+
url="https://figshare.com/ndownloader/files/49828902",
|
70
70
|
output_file_name=output_file_name,
|
71
71
|
output_path=settings.datasetdir,
|
72
72
|
is_zip=False,
|
@@ -1100,7 +1100,7 @@ def shifrut_2018() -> AnnData: # pragma: no cover
|
|
1100
1100
|
output_file_path = settings.datasetdir / output_file_name
|
1101
1101
|
if not Path(output_file_path).exists():
|
1102
1102
|
_download(
|
1103
|
-
url="https://zenodo.org/record/
|
1103
|
+
url="https://zenodo.org/record/13350497/files/ShifrutMarson2018.h5ad?download=1",
|
1104
1104
|
output_file_name=output_file_name,
|
1105
1105
|
output_path=settings.datasetdir,
|
1106
1106
|
is_zip=False,
|
@@ -1160,7 +1160,7 @@ def srivatsan_2020_sciplex3() -> AnnData: # pragma: no cover
|
|
1160
1160
|
output_file_path = settings.datasetdir / output_file_name
|
1161
1161
|
if not Path(output_file_path).exists():
|
1162
1162
|
_download(
|
1163
|
-
url="https://zenodo.org/records/
|
1163
|
+
url="https://zenodo.org/records/13350497/files/SrivatsanTrapnell2020_sciplex3.h5ad?download=1",
|
1164
1164
|
output_file_name=output_file_name,
|
1165
1165
|
output_path=settings.datasetdir,
|
1166
1166
|
is_zip=False,
|
pertpy/metadata/_cell_line.py
CHANGED
@@ -8,12 +8,15 @@ from lamin_utils import logger
|
|
8
8
|
if TYPE_CHECKING:
|
9
9
|
from collections.abc import Iterable
|
10
10
|
|
11
|
+
from matplotlib.pyplot import Figure
|
12
|
+
|
11
13
|
import matplotlib.pyplot as plt
|
12
14
|
import numpy as np
|
13
15
|
import pandas as pd
|
14
16
|
from scanpy import settings
|
15
17
|
from scipy import stats
|
16
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
20
|
from pertpy.data._dataloader import _download
|
18
21
|
|
19
22
|
from ._look_up import LookUp
|
@@ -338,8 +341,8 @@ class CellLine(MetaData):
|
|
338
341
|
# then we can compare these keys and fetch the corresponding metadata.
|
339
342
|
if query_id not in adata.obs.columns:
|
340
343
|
raise ValueError(
|
341
|
-
f"The specified `query_id` {query_id} can't be found in the `adata.obs
|
342
|
-
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation
|
344
|
+
f"The specified `query_id` {query_id} can't be found in the `adata.obs`. \n"
|
345
|
+
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation."
|
343
346
|
"If the desired query ID is not available, you can fetch the cell line metadata "
|
344
347
|
"using the `annotate()` function before calling 'annotate_bulk_rna()'. "
|
345
348
|
"This ensures that the required query ID is included in your data, e.g. stripped_cell_line_name, DepMap ID."
|
@@ -356,9 +359,8 @@ class CellLine(MetaData):
|
|
356
359
|
else:
|
357
360
|
reference_id = "DepMap_ID"
|
358
361
|
logger.warning(
|
359
|
-
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given
|
360
|
-
"
|
361
|
-
"Alternatively, use `annotate()` to annotate the cell line first "
|
362
|
+
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given."
|
363
|
+
"If `DepMap_ID` isn't available in 'adata.obs', use `annotate()` to annotate the cell line first."
|
362
364
|
)
|
363
365
|
if self.bulk_rna_broad is None:
|
364
366
|
self._download_bulk_rna(cell_line_source="broad")
|
@@ -690,6 +692,7 @@ class CellLine(MetaData):
|
|
690
692
|
|
691
693
|
return corr, pvals, new_corr, new_pvals
|
692
694
|
|
695
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
693
696
|
def plot_correlation(
|
694
697
|
self,
|
695
698
|
adata: AnnData,
|
@@ -700,7 +703,9 @@ class CellLine(MetaData):
|
|
700
703
|
metadata_key: str = "bulk_rna_broad",
|
701
704
|
category: str = "cell line",
|
702
705
|
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
|
703
|
-
|
706
|
+
show: bool = True,
|
707
|
+
return_fig: bool = False,
|
708
|
+
) -> Figure | None:
|
704
709
|
"""Visualise the correlation of cell lines with annotated metadata.
|
705
710
|
|
706
711
|
Args:
|
@@ -713,6 +718,8 @@ class CellLine(MetaData):
|
|
713
718
|
subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
|
714
719
|
If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
|
715
720
|
If None, all cell lines will be plotted.
|
721
|
+
{common_plot_args}
|
722
|
+
|
716
723
|
Returns:
|
717
724
|
Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
|
718
725
|
"""
|
@@ -790,6 +797,11 @@ class CellLine(MetaData):
|
|
790
797
|
"edgecolor": "black",
|
791
798
|
},
|
792
799
|
)
|
793
|
-
|
800
|
+
|
801
|
+
if show:
|
802
|
+
plt.show()
|
803
|
+
if return_fig:
|
804
|
+
return plt.gcf()
|
805
|
+
return None
|
794
806
|
else:
|
795
807
|
raise NotImplementedError
|
pertpy/metadata/_compound.py
CHANGED
@@ -42,7 +42,7 @@ class Compound(MetaData):
|
|
42
42
|
adata = adata.copy()
|
43
43
|
|
44
44
|
if query_id not in adata.obs.columns:
|
45
|
-
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n
|
45
|
+
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
|
46
46
|
|
47
47
|
query_dict = {}
|
48
48
|
not_matched_identifiers = []
|
@@ -84,7 +84,7 @@ class Compound(MetaData):
|
|
84
84
|
query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
|
85
85
|
# Merge and remove duplicate columns
|
86
86
|
# Column is converted to float after merging due to unmatches
|
87
|
-
# Convert back to integers
|
87
|
+
# Convert back to integers afterwards
|
88
88
|
if query_id_type == "cid":
|
89
89
|
query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
|
90
90
|
adata.obs = (
|
@@ -119,8 +119,7 @@ class Compound(MetaData):
|
|
119
119
|
|
120
120
|
The LookUp object provides an overview of the metadata to annotate.
|
121
121
|
Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
|
122
|
-
where users can search the reference_id in the metadata and
|
123
|
-
compare with the query_id in their own data.
|
122
|
+
where users can search the reference_id in the metadata and compare with the query_id in their own data.
|
124
123
|
|
125
124
|
Returns:
|
126
125
|
Returns a LookUp object specific for compound annotation.
|
pertpy/metadata/_metadata.py
CHANGED
@@ -62,7 +62,7 @@ class MetaData:
|
|
62
62
|
if verbosity > 0:
|
63
63
|
logger.info(
|
64
64
|
f"There are {total_identifiers} identifiers in `adata.obs`."
|
65
|
-
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation,"
|
65
|
+
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation, "
|
66
66
|
"leading to the presence of NA values for their respective metadata.\n"
|
67
67
|
f"Please check again: *unmatched_identifiers[:verbosity]..."
|
68
68
|
)
|
@@ -3,14 +3,17 @@ from __future__ import annotations
|
|
3
3
|
import uuid
|
4
4
|
from typing import TYPE_CHECKING
|
5
5
|
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
9
10
|
import scipy
|
10
11
|
|
12
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
13
|
+
|
11
14
|
if TYPE_CHECKING:
|
12
15
|
from anndata import AnnData
|
13
|
-
from matplotlib.
|
16
|
+
from matplotlib.pyplot import Figure
|
14
17
|
|
15
18
|
|
16
19
|
class GuideAssignment:
|
@@ -106,14 +109,18 @@ class GuideAssignment:
|
|
106
109
|
|
107
110
|
return None
|
108
111
|
|
112
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
109
113
|
def plot_heatmap(
|
110
114
|
self,
|
111
115
|
adata: AnnData,
|
116
|
+
*,
|
112
117
|
layer: str | None = None,
|
113
118
|
order_by: np.ndarray | str | None = None,
|
114
119
|
key_to_save_order: str = None,
|
120
|
+
show: bool = True,
|
121
|
+
return_fig: bool = False,
|
115
122
|
**kwargs,
|
116
|
-
) ->
|
123
|
+
) -> Figure | None:
|
117
124
|
"""Heatmap plotting of guide RNA expression matrix.
|
118
125
|
|
119
126
|
Assuming guides have sparse expression, this function reorders cells
|
@@ -131,11 +138,12 @@ class GuideAssignment:
|
|
131
138
|
If a string is provided, adata.obs[order_by] will be used as the order.
|
132
139
|
If a numpy array is provided, the array will be used for ordering.
|
133
140
|
key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
|
141
|
+
{common_plot_args}
|
134
142
|
kwargs: Are passed to sc.pl.heatmap.
|
135
143
|
|
136
144
|
Returns:
|
137
|
-
|
138
|
-
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
|
145
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
146
|
+
Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
|
139
147
|
|
140
148
|
Examples:
|
141
149
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
|
@@ -172,7 +180,7 @@ class GuideAssignment:
|
|
172
180
|
adata.obs[key_to_save_order] = pd.Categorical(order)
|
173
181
|
|
174
182
|
try:
|
175
|
-
|
183
|
+
fig = sc.pl.heatmap(
|
176
184
|
adata[order, :],
|
177
185
|
var_names=adata.var.index.tolist(),
|
178
186
|
groupby=temp_col_name,
|
@@ -180,9 +188,14 @@ class GuideAssignment:
|
|
180
188
|
use_raw=False,
|
181
189
|
dendrogram=False,
|
182
190
|
layer=layer,
|
191
|
+
show=False,
|
183
192
|
**kwargs,
|
184
193
|
)
|
185
194
|
finally:
|
186
195
|
del adata.obs[temp_col_name]
|
187
196
|
|
188
|
-
|
197
|
+
if show:
|
198
|
+
plt.show()
|
199
|
+
if return_fig:
|
200
|
+
return fig
|
201
|
+
return None
|
pertpy/tools/__init__.py
CHANGED
@@ -1,25 +1,22 @@
|
|
1
|
-
from functools import wraps
|
2
1
|
from importlib import import_module
|
3
2
|
|
4
3
|
|
5
4
|
def lazy_import(module_path, class_name, extras):
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
import_module(extra)
|
10
|
-
except ImportError as e:
|
11
|
-
raise ImportError(
|
12
|
-
f"Extra dependencies required: {', '.join(extras)}. "
|
13
|
-
f"Please install with: pip install {' '.join(extras)}"
|
14
|
-
) from e
|
5
|
+
try:
|
6
|
+
for extra in extras:
|
7
|
+
import_module(extra)
|
15
8
|
module = import_module(module_path)
|
16
9
|
return getattr(module, class_name)
|
10
|
+
except ImportError:
|
17
11
|
|
18
|
-
|
19
|
-
|
20
|
-
|
12
|
+
class Placeholder:
|
13
|
+
def __init__(self, *args, **kwargs):
|
14
|
+
raise ImportError(
|
15
|
+
f"Extra dependencies required: {', '.join(extras)}. "
|
16
|
+
f"Please install with: pip install {' '.join(extras)}"
|
17
|
+
)
|
21
18
|
|
22
|
-
|
19
|
+
return Placeholder
|
23
20
|
|
24
21
|
|
25
22
|
from pertpy.tools._augur import Augur
|
@@ -49,7 +46,7 @@ Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
|
|
49
46
|
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
|
50
47
|
|
51
48
|
DE_EXTRAS = ["formulaic", "pydeseq2"]
|
52
|
-
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS
|
49
|
+
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
|
53
50
|
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
|
54
51
|
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
|
55
52
|
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
|
pertpy/tools/_augur.py
CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
|
|
15
15
|
from anndata import AnnData
|
16
16
|
from joblib import Parallel, delayed
|
17
17
|
from lamin_utils import logger
|
18
|
-
from rich import print
|
19
18
|
from rich.progress import track
|
20
19
|
from scipy import sparse, stats
|
21
20
|
from sklearn.base import is_classifier, is_regressor
|
@@ -26,17 +25,19 @@ from sklearn.metrics import (
|
|
26
25
|
explained_variance_score,
|
27
26
|
f1_score,
|
28
27
|
make_scorer,
|
29
|
-
mean_squared_error,
|
30
28
|
precision_score,
|
31
29
|
r2_score,
|
32
30
|
recall_score,
|
33
31
|
roc_auc_score,
|
32
|
+
root_mean_squared_error,
|
34
33
|
)
|
35
34
|
from sklearn.model_selection import StratifiedKFold, cross_validate
|
36
35
|
from sklearn.preprocessing import LabelEncoder
|
37
36
|
from skmisc.loess import loess
|
38
37
|
from statsmodels.stats.multitest import fdrcorrection
|
39
38
|
|
39
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
40
|
+
|
40
41
|
if TYPE_CHECKING:
|
41
42
|
from matplotlib.axes import Axes
|
42
43
|
from matplotlib.figure import Figure
|
@@ -439,7 +440,7 @@ class Augur:
|
|
439
440
|
"augur_score": make_scorer(self.ccc_score),
|
440
441
|
"r2": make_scorer(r2_score),
|
441
442
|
"ccc": make_scorer(self.ccc_score),
|
442
|
-
"neg_mean_squared_error": make_scorer(
|
443
|
+
"neg_mean_squared_error": make_scorer(root_mean_squared_error),
|
443
444
|
"explained_variance": make_scorer(explained_variance_score),
|
444
445
|
}
|
445
446
|
)
|
@@ -974,24 +975,26 @@ class Augur:
|
|
974
975
|
|
975
976
|
return delta
|
976
977
|
|
978
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
977
979
|
def plot_dp_scatter(
|
978
980
|
self,
|
979
981
|
results: pd.DataFrame,
|
982
|
+
*,
|
980
983
|
top_n: int = None,
|
981
|
-
return_fig: bool | None = None,
|
982
984
|
ax: Axes = None,
|
983
|
-
show: bool
|
984
|
-
|
985
|
-
) ->
|
985
|
+
show: bool = True,
|
986
|
+
return_fig: bool = False,
|
987
|
+
) -> Figure | None:
|
986
988
|
"""Plot scatterplot of differential prioritization.
|
987
989
|
|
988
990
|
Args:
|
989
991
|
results: Results after running differential prioritization.
|
990
992
|
top_n: optionally, the number of top prioritized cell types to label in the plot
|
991
993
|
ax: optionally, axes used to draw plot
|
994
|
+
{common_plot_args}
|
992
995
|
|
993
996
|
Returns:
|
994
|
-
|
997
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
995
998
|
|
996
999
|
Examples:
|
997
1000
|
>>> import pertpy as pt
|
@@ -1038,37 +1041,34 @@ class Augur:
|
|
1038
1041
|
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
|
1039
1042
|
ax.add_artist(legend1)
|
1040
1043
|
|
1041
|
-
if save:
|
1042
|
-
plt.savefig(save, bbox_inches="tight")
|
1043
1044
|
if show:
|
1044
1045
|
plt.show()
|
1045
1046
|
if return_fig:
|
1046
1047
|
return plt.gcf()
|
1047
|
-
if not (show or save):
|
1048
|
-
return ax
|
1049
1048
|
return None
|
1050
1049
|
|
1050
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1051
1051
|
def plot_important_features(
|
1052
1052
|
self,
|
1053
1053
|
data: dict[str, Any],
|
1054
|
+
*,
|
1054
1055
|
key: str = "augurpy_results",
|
1055
1056
|
top_n: int = 10,
|
1056
|
-
return_fig: bool | None = None,
|
1057
1057
|
ax: Axes = None,
|
1058
|
-
show: bool
|
1059
|
-
|
1060
|
-
) ->
|
1058
|
+
show: bool = True,
|
1059
|
+
return_fig: bool = False,
|
1060
|
+
) -> Figure | None:
|
1061
1061
|
"""Plot a lollipop plot of the n features with largest feature importances.
|
1062
1062
|
|
1063
1063
|
Args:
|
1064
|
-
|
1064
|
+
data: results after running `predict()` as dictionary or the AnnData object.
|
1065
1065
|
key: Key in the AnnData object of the results
|
1066
1066
|
top_n: n number feature importance values to plot. Default is 10.
|
1067
1067
|
ax: optionally, axes used to draw plot
|
1068
|
-
|
1068
|
+
{common_plot_args}
|
1069
1069
|
|
1070
1070
|
Returns:
|
1071
|
-
|
1071
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1072
1072
|
|
1073
1073
|
Examples:
|
1074
1074
|
>>> import pertpy as pt
|
@@ -1109,35 +1109,32 @@ class Augur:
|
|
1109
1109
|
plt.ylabel("Gene")
|
1110
1110
|
plt.yticks(y_axes_range, n_features["genes"])
|
1111
1111
|
|
1112
|
-
if save:
|
1113
|
-
plt.savefig(save, bbox_inches="tight")
|
1114
1112
|
if show:
|
1115
1113
|
plt.show()
|
1116
1114
|
if return_fig:
|
1117
1115
|
return plt.gcf()
|
1118
|
-
if not (show or save):
|
1119
|
-
return ax
|
1120
1116
|
return None
|
1121
1117
|
|
1118
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1122
1119
|
def plot_lollipop(
|
1123
1120
|
self,
|
1124
|
-
data: dict[str, Any],
|
1121
|
+
data: dict[str, Any] | AnnData,
|
1122
|
+
*,
|
1125
1123
|
key: str = "augurpy_results",
|
1126
|
-
return_fig: bool | None = None,
|
1127
1124
|
ax: Axes = None,
|
1128
|
-
show: bool
|
1129
|
-
|
1130
|
-
) ->
|
1125
|
+
show: bool = True,
|
1126
|
+
return_fig: bool = False,
|
1127
|
+
) -> Figure | None:
|
1131
1128
|
"""Plot a lollipop plot of the mean augur values.
|
1132
1129
|
|
1133
1130
|
Args:
|
1134
|
-
|
1135
|
-
key:
|
1136
|
-
ax: optionally, axes used to draw plot
|
1137
|
-
|
1131
|
+
data: results after running `predict()` as dictionary or the AnnData object.
|
1132
|
+
key: .uns key in the results AnnData object.
|
1133
|
+
ax: optionally, axes used to draw plot.
|
1134
|
+
{common_plot_args}
|
1138
1135
|
|
1139
1136
|
Returns:
|
1140
|
-
|
1137
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1141
1138
|
|
1142
1139
|
Examples:
|
1143
1140
|
>>> import pertpy as pt
|
@@ -1175,32 +1172,29 @@ class Augur:
|
|
1175
1172
|
plt.ylabel("Cell Type")
|
1176
1173
|
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
|
1177
1174
|
|
1178
|
-
if save:
|
1179
|
-
plt.savefig(save, bbox_inches="tight")
|
1180
1175
|
if show:
|
1181
1176
|
plt.show()
|
1182
1177
|
if return_fig:
|
1183
1178
|
return plt.gcf()
|
1184
|
-
if not (show or save):
|
1185
|
-
return ax
|
1186
1179
|
return None
|
1187
1180
|
|
1181
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1188
1182
|
def plot_scatterplot(
|
1189
1183
|
self,
|
1190
1184
|
results1: dict[str, Any],
|
1191
1185
|
results2: dict[str, Any],
|
1186
|
+
*,
|
1192
1187
|
top_n: int = None,
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
) -> Axes | Figure | None:
|
1188
|
+
show: bool = True,
|
1189
|
+
return_fig: bool = False,
|
1190
|
+
) -> Figure | None:
|
1197
1191
|
"""Create scatterplot with two augur results.
|
1198
1192
|
|
1199
1193
|
Args:
|
1200
1194
|
results1: results after running `predict()`
|
1201
1195
|
results2: results after running `predict()`
|
1202
1196
|
top_n: optionally, the number of top prioritized cell types to label in the plot
|
1203
|
-
|
1197
|
+
{common_plot_args}
|
1204
1198
|
|
1205
1199
|
Returns:
|
1206
1200
|
Axes of the plot.
|
@@ -1249,12 +1243,8 @@ class Augur:
|
|
1249
1243
|
plt.xlabel("Augur scores 1")
|
1250
1244
|
plt.ylabel("Augur scores 2")
|
1251
1245
|
|
1252
|
-
if save:
|
1253
|
-
plt.savefig(save, bbox_inches="tight")
|
1254
1246
|
if show:
|
1255
1247
|
plt.show()
|
1256
1248
|
if return_fig:
|
1257
1249
|
return plt.gcf()
|
1258
|
-
if not (show or save):
|
1259
|
-
return ax
|
1260
1250
|
return None
|
pertpy/tools/_cinemaot.py
CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
|
|
18
18
|
from sklearn.linear_model import LinearRegression
|
19
19
|
from sklearn.neighbors import NearestNeighbors
|
20
20
|
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
|
+
|
21
23
|
if TYPE_CHECKING:
|
22
24
|
from anndata import AnnData
|
23
25
|
from matplotlib.axes import Axes
|
26
|
+
from matplotlib.pyplot import Figure
|
24
27
|
from statsmodels.tools.typing import ArrayLike
|
25
28
|
|
26
29
|
|
@@ -88,7 +91,7 @@ class Cinemaot:
|
|
88
91
|
dim = self.get_dim(adata, use_rep=use_rep)
|
89
92
|
|
90
93
|
transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
|
91
|
-
X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
|
94
|
+
X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
|
92
95
|
groupvec = (adata.obs[pert_key] == control * 1).values # control
|
93
96
|
xi = np.zeros(dim)
|
94
97
|
j = 0
|
@@ -97,9 +100,9 @@ class Cinemaot:
|
|
97
100
|
xi[j] = xi_obj.correlation
|
98
101
|
j = j + 1
|
99
102
|
|
100
|
-
cf = X_transformed[:, xi < thres]
|
101
|
-
cf1 = cf[adata.obs[pert_key] == control, :]
|
102
|
-
cf2 = cf[adata.obs[pert_key] != control, :]
|
103
|
+
cf = np.array(X_transformed[:, xi < thres], np.float64)
|
104
|
+
cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
|
105
|
+
cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
|
103
106
|
if sum(xi < thres) == 1:
|
104
107
|
sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
|
105
108
|
elif sum(xi < thres) == 0:
|
@@ -167,7 +170,7 @@ class Cinemaot:
|
|
167
170
|
else:
|
168
171
|
_solver = sinkhorn.Sinkhorn(threshold=eps)
|
169
172
|
ot_sink = _solver(ot_prob)
|
170
|
-
ot_matrix = ot_sink.matrix.T
|
173
|
+
ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
|
171
174
|
embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
|
172
175
|
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
|
173
176
|
)
|
@@ -190,7 +193,7 @@ class Cinemaot:
|
|
190
193
|
TE.obsm["X_embedding"] = embedding
|
191
194
|
|
192
195
|
if return_matching:
|
193
|
-
TE.obsm["ot"] = ot_sink.matrix.T
|
196
|
+
TE.obsm["ot"] = np.asarray(ot_sink.matrix.T)
|
194
197
|
return TE
|
195
198
|
else:
|
196
199
|
return TE
|
@@ -639,6 +642,7 @@ class Cinemaot:
|
|
639
642
|
s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
|
640
643
|
return c_effect, s_effect
|
641
644
|
|
645
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
642
646
|
def plot_vis_matching(
|
643
647
|
self,
|
644
648
|
adata: AnnData,
|
@@ -647,16 +651,17 @@ class Cinemaot:
|
|
647
651
|
control: str,
|
648
652
|
de_label: str,
|
649
653
|
source_label: str,
|
654
|
+
*,
|
650
655
|
matching_rep: str = "ot",
|
651
656
|
resolution: float = 0.5,
|
652
657
|
normalize: str = "col",
|
653
658
|
title: str = "CINEMA-OT matching matrix",
|
654
659
|
min_val: float = 0.01,
|
655
|
-
show: bool = True,
|
656
|
-
save: str | None = None,
|
657
660
|
ax: Axes | None = None,
|
661
|
+
show: bool = True,
|
662
|
+
return_fig: bool = False,
|
658
663
|
**kwargs,
|
659
|
-
) -> None:
|
664
|
+
) -> Figure | None:
|
660
665
|
"""Visualize the CINEMA-OT matching matrix.
|
661
666
|
|
662
667
|
Args:
|
@@ -670,11 +675,12 @@ class Cinemaot:
|
|
670
675
|
normalize: normalize the coarse-grained matching matrix by row / column.
|
671
676
|
title: the title for the figure.
|
672
677
|
min_val: The min value to truncate the matching matrix.
|
673
|
-
|
674
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
675
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
678
|
+
{common_plot_args}
|
676
679
|
**kwargs: Other parameters to input for seaborn.heatmap.
|
677
680
|
|
681
|
+
Returns:
|
682
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
683
|
+
|
678
684
|
Examples:
|
679
685
|
>>> import pertpy as pt
|
680
686
|
>>> adata = pt.dt.cinemaot_example()
|
@@ -710,12 +716,12 @@ class Cinemaot:
|
|
710
716
|
|
711
717
|
g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
|
712
718
|
plt.title(title)
|
713
|
-
|
714
|
-
if
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
+
|
720
|
+
if show:
|
721
|
+
plt.show()
|
722
|
+
if return_fig:
|
723
|
+
return g
|
724
|
+
return None
|
719
725
|
|
720
726
|
|
721
727
|
class Xi:
|