pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
pertpy/_doc.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
from textwrap import dedent
|
2
|
+
|
3
|
+
|
4
|
+
def _doc_params(**kwds): # pragma: no cover
|
5
|
+
"""\
|
6
|
+
Docstrings should start with "\" in the first line for proper formatting.
|
7
|
+
"""
|
8
|
+
|
9
|
+
def dec(obj):
|
10
|
+
obj.__orig_doc__ = obj.__doc__
|
11
|
+
obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
|
12
|
+
return obj
|
13
|
+
|
14
|
+
return dec
|
15
|
+
|
16
|
+
|
17
|
+
doc_common_plot_args = """\
|
18
|
+
show: if `True`, shows the plot.
|
19
|
+
return_fig: if `True`, returns figure of the plot.\
|
20
|
+
"""
|
pertpy/data/_dataloader.py
CHANGED
@@ -23,10 +23,10 @@ def _download( # pragma: no cover
|
|
23
23
|
Args:
|
24
24
|
url: URL to download
|
25
25
|
output_file_name: Name of the downloaded file
|
26
|
-
output_path: Path to download/extract the files to
|
27
|
-
block_size: Block size for downloads in bytes
|
28
|
-
overwrite: Whether to overwrite existing files
|
29
|
-
is_zip: Whether the downloaded file needs to be unzipped
|
26
|
+
output_path: Path to download/extract the files to.
|
27
|
+
block_size: Block size for downloads in bytes.
|
28
|
+
overwrite: Whether to overwrite existing files.
|
29
|
+
is_zip: Whether the downloaded file needs to be unzipped.
|
30
30
|
"""
|
31
31
|
if output_file_name is None:
|
32
32
|
letters = ascii_lowercase
|
pertpy/data/_datasets.py
CHANGED
@@ -66,7 +66,7 @@ def sc_sim_augur() -> AnnData: # pragma: no cover
|
|
66
66
|
output_file_path = settings.datasetdir / output_file_name
|
67
67
|
if not Path(output_file_path).exists():
|
68
68
|
_download(
|
69
|
-
url="https://figshare.com/ndownloader/files/
|
69
|
+
url="https://figshare.com/ndownloader/files/49828902",
|
70
70
|
output_file_name=output_file_name,
|
71
71
|
output_path=settings.datasetdir,
|
72
72
|
is_zip=False,
|
@@ -1100,7 +1100,7 @@ def shifrut_2018() -> AnnData: # pragma: no cover
|
|
1100
1100
|
output_file_path = settings.datasetdir / output_file_name
|
1101
1101
|
if not Path(output_file_path).exists():
|
1102
1102
|
_download(
|
1103
|
-
url="https://zenodo.org/record/
|
1103
|
+
url="https://zenodo.org/record/13350497/files/ShifrutMarson2018.h5ad?download=1",
|
1104
1104
|
output_file_name=output_file_name,
|
1105
1105
|
output_path=settings.datasetdir,
|
1106
1106
|
is_zip=False,
|
@@ -1160,7 +1160,7 @@ def srivatsan_2020_sciplex3() -> AnnData: # pragma: no cover
|
|
1160
1160
|
output_file_path = settings.datasetdir / output_file_name
|
1161
1161
|
if not Path(output_file_path).exists():
|
1162
1162
|
_download(
|
1163
|
-
url="https://zenodo.org/records/
|
1163
|
+
url="https://zenodo.org/records/13350497/files/SrivatsanTrapnell2020_sciplex3.h5ad?download=1",
|
1164
1164
|
output_file_name=output_file_name,
|
1165
1165
|
output_path=settings.datasetdir,
|
1166
1166
|
is_zip=False,
|
pertpy/metadata/_cell_line.py
CHANGED
@@ -8,12 +8,15 @@ from lamin_utils import logger
|
|
8
8
|
if TYPE_CHECKING:
|
9
9
|
from collections.abc import Iterable
|
10
10
|
|
11
|
+
from matplotlib.pyplot import Figure
|
12
|
+
|
11
13
|
import matplotlib.pyplot as plt
|
12
14
|
import numpy as np
|
13
15
|
import pandas as pd
|
14
16
|
from scanpy import settings
|
15
17
|
from scipy import stats
|
16
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
20
|
from pertpy.data._dataloader import _download
|
18
21
|
|
19
22
|
from ._look_up import LookUp
|
@@ -338,8 +341,8 @@ class CellLine(MetaData):
|
|
338
341
|
# then we can compare these keys and fetch the corresponding metadata.
|
339
342
|
if query_id not in adata.obs.columns:
|
340
343
|
raise ValueError(
|
341
|
-
f"The specified `query_id` {query_id} can't be found in the `adata.obs
|
342
|
-
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation
|
344
|
+
f"The specified `query_id` {query_id} can't be found in the `adata.obs`. \n"
|
345
|
+
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation."
|
343
346
|
"If the desired query ID is not available, you can fetch the cell line metadata "
|
344
347
|
"using the `annotate()` function before calling 'annotate_bulk_rna()'. "
|
345
348
|
"This ensures that the required query ID is included in your data, e.g. stripped_cell_line_name, DepMap ID."
|
@@ -356,9 +359,8 @@ class CellLine(MetaData):
|
|
356
359
|
else:
|
357
360
|
reference_id = "DepMap_ID"
|
358
361
|
logger.warning(
|
359
|
-
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given
|
360
|
-
"
|
361
|
-
"Alternatively, use `annotate()` to annotate the cell line first "
|
362
|
+
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given."
|
363
|
+
"If `DepMap_ID` isn't available in 'adata.obs', use `annotate()` to annotate the cell line first."
|
362
364
|
)
|
363
365
|
if self.bulk_rna_broad is None:
|
364
366
|
self._download_bulk_rna(cell_line_source="broad")
|
@@ -690,6 +692,7 @@ class CellLine(MetaData):
|
|
690
692
|
|
691
693
|
return corr, pvals, new_corr, new_pvals
|
692
694
|
|
695
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
693
696
|
def plot_correlation(
|
694
697
|
self,
|
695
698
|
adata: AnnData,
|
@@ -700,7 +703,9 @@ class CellLine(MetaData):
|
|
700
703
|
metadata_key: str = "bulk_rna_broad",
|
701
704
|
category: str = "cell line",
|
702
705
|
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
|
703
|
-
|
706
|
+
show: bool = True,
|
707
|
+
return_fig: bool = False,
|
708
|
+
) -> Figure | None:
|
704
709
|
"""Visualise the correlation of cell lines with annotated metadata.
|
705
710
|
|
706
711
|
Args:
|
@@ -713,6 +718,8 @@ class CellLine(MetaData):
|
|
713
718
|
subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
|
714
719
|
If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
|
715
720
|
If None, all cell lines will be plotted.
|
721
|
+
{common_plot_args}
|
722
|
+
|
716
723
|
Returns:
|
717
724
|
Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
|
718
725
|
"""
|
@@ -790,6 +797,11 @@ class CellLine(MetaData):
|
|
790
797
|
"edgecolor": "black",
|
791
798
|
},
|
792
799
|
)
|
793
|
-
|
800
|
+
|
801
|
+
if show:
|
802
|
+
plt.show()
|
803
|
+
if return_fig:
|
804
|
+
return plt.gcf()
|
805
|
+
return None
|
794
806
|
else:
|
795
807
|
raise NotImplementedError
|
pertpy/metadata/_compound.py
CHANGED
@@ -42,7 +42,7 @@ class Compound(MetaData):
|
|
42
42
|
adata = adata.copy()
|
43
43
|
|
44
44
|
if query_id not in adata.obs.columns:
|
45
|
-
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n
|
45
|
+
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
|
46
46
|
|
47
47
|
query_dict = {}
|
48
48
|
not_matched_identifiers = []
|
@@ -84,7 +84,7 @@ class Compound(MetaData):
|
|
84
84
|
query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
|
85
85
|
# Merge and remove duplicate columns
|
86
86
|
# Column is converted to float after merging due to unmatches
|
87
|
-
# Convert back to integers
|
87
|
+
# Convert back to integers afterwards
|
88
88
|
if query_id_type == "cid":
|
89
89
|
query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
|
90
90
|
adata.obs = (
|
@@ -119,8 +119,7 @@ class Compound(MetaData):
|
|
119
119
|
|
120
120
|
The LookUp object provides an overview of the metadata to annotate.
|
121
121
|
Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
|
122
|
-
where users can search the reference_id in the metadata and
|
123
|
-
compare with the query_id in their own data.
|
122
|
+
where users can search the reference_id in the metadata and compare with the query_id in their own data.
|
124
123
|
|
125
124
|
Returns:
|
126
125
|
Returns a LookUp object specific for compound annotation.
|
pertpy/metadata/_metadata.py
CHANGED
@@ -62,7 +62,7 @@ class MetaData:
|
|
62
62
|
if verbosity > 0:
|
63
63
|
logger.info(
|
64
64
|
f"There are {total_identifiers} identifiers in `adata.obs`."
|
65
|
-
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation,"
|
65
|
+
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation, "
|
66
66
|
"leading to the presence of NA values for their respective metadata.\n"
|
67
67
|
f"Please check again: *unmatched_identifiers[:verbosity]..."
|
68
68
|
)
|
@@ -3,14 +3,17 @@ from __future__ import annotations
|
|
3
3
|
import uuid
|
4
4
|
from typing import TYPE_CHECKING
|
5
5
|
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
9
10
|
import scipy
|
10
11
|
|
12
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
13
|
+
|
11
14
|
if TYPE_CHECKING:
|
12
15
|
from anndata import AnnData
|
13
|
-
from matplotlib.
|
16
|
+
from matplotlib.pyplot import Figure
|
14
17
|
|
15
18
|
|
16
19
|
class GuideAssignment:
|
@@ -106,14 +109,18 @@ class GuideAssignment:
|
|
106
109
|
|
107
110
|
return None
|
108
111
|
|
112
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
109
113
|
def plot_heatmap(
|
110
114
|
self,
|
111
115
|
adata: AnnData,
|
116
|
+
*,
|
112
117
|
layer: str | None = None,
|
113
118
|
order_by: np.ndarray | str | None = None,
|
114
119
|
key_to_save_order: str = None,
|
120
|
+
show: bool = True,
|
121
|
+
return_fig: bool = False,
|
115
122
|
**kwargs,
|
116
|
-
) ->
|
123
|
+
) -> Figure | None:
|
117
124
|
"""Heatmap plotting of guide RNA expression matrix.
|
118
125
|
|
119
126
|
Assuming guides have sparse expression, this function reorders cells
|
@@ -131,11 +138,12 @@ class GuideAssignment:
|
|
131
138
|
If a string is provided, adata.obs[order_by] will be used as the order.
|
132
139
|
If a numpy array is provided, the array will be used for ordering.
|
133
140
|
key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
|
141
|
+
{common_plot_args}
|
134
142
|
kwargs: Are passed to sc.pl.heatmap.
|
135
143
|
|
136
144
|
Returns:
|
137
|
-
|
138
|
-
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
|
145
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
146
|
+
Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
|
139
147
|
|
140
148
|
Examples:
|
141
149
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
|
@@ -172,7 +180,7 @@ class GuideAssignment:
|
|
172
180
|
adata.obs[key_to_save_order] = pd.Categorical(order)
|
173
181
|
|
174
182
|
try:
|
175
|
-
|
183
|
+
fig = sc.pl.heatmap(
|
176
184
|
adata[order, :],
|
177
185
|
var_names=adata.var.index.tolist(),
|
178
186
|
groupby=temp_col_name,
|
@@ -180,9 +188,14 @@ class GuideAssignment:
|
|
180
188
|
use_raw=False,
|
181
189
|
dendrogram=False,
|
182
190
|
layer=layer,
|
191
|
+
show=False,
|
183
192
|
**kwargs,
|
184
193
|
)
|
185
194
|
finally:
|
186
195
|
del adata.obs[temp_col_name]
|
187
196
|
|
188
|
-
|
197
|
+
if show:
|
198
|
+
plt.show()
|
199
|
+
if return_fig:
|
200
|
+
return fig
|
201
|
+
return None
|
pertpy/tools/__init__.py
CHANGED
@@ -1,25 +1,22 @@
|
|
1
|
-
from functools import wraps
|
2
1
|
from importlib import import_module
|
3
2
|
|
4
3
|
|
5
4
|
def lazy_import(module_path, class_name, extras):
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
import_module(extra)
|
10
|
-
except ImportError as e:
|
11
|
-
raise ImportError(
|
12
|
-
f"Extra dependencies required: {', '.join(extras)}. "
|
13
|
-
f"Please install with: pip install {' '.join(extras)}"
|
14
|
-
) from e
|
5
|
+
try:
|
6
|
+
for extra in extras:
|
7
|
+
import_module(extra)
|
15
8
|
module = import_module(module_path)
|
16
9
|
return getattr(module, class_name)
|
10
|
+
except ImportError:
|
17
11
|
|
18
|
-
|
19
|
-
|
20
|
-
|
12
|
+
class Placeholder:
|
13
|
+
def __init__(self, *args, **kwargs):
|
14
|
+
raise ImportError(
|
15
|
+
f"Extra dependencies required: {', '.join(extras)}. "
|
16
|
+
f"Please install with: pip install {' '.join(extras)}"
|
17
|
+
)
|
21
18
|
|
22
|
-
|
19
|
+
return Placeholder
|
23
20
|
|
24
21
|
|
25
22
|
from pertpy.tools._augur import Augur
|
@@ -49,7 +46,7 @@ Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
|
|
49
46
|
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
|
50
47
|
|
51
48
|
DE_EXTRAS = ["formulaic", "pydeseq2"]
|
52
|
-
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS
|
49
|
+
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
|
53
50
|
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
|
54
51
|
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
|
55
52
|
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
|
pertpy/tools/_augur.py
CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
|
|
15
15
|
from anndata import AnnData
|
16
16
|
from joblib import Parallel, delayed
|
17
17
|
from lamin_utils import logger
|
18
|
-
from rich import print
|
19
18
|
from rich.progress import track
|
20
19
|
from scipy import sparse, stats
|
21
20
|
from sklearn.base import is_classifier, is_regressor
|
@@ -26,17 +25,19 @@ from sklearn.metrics import (
|
|
26
25
|
explained_variance_score,
|
27
26
|
f1_score,
|
28
27
|
make_scorer,
|
29
|
-
mean_squared_error,
|
30
28
|
precision_score,
|
31
29
|
r2_score,
|
32
30
|
recall_score,
|
33
31
|
roc_auc_score,
|
32
|
+
root_mean_squared_error,
|
34
33
|
)
|
35
34
|
from sklearn.model_selection import StratifiedKFold, cross_validate
|
36
35
|
from sklearn.preprocessing import LabelEncoder
|
37
36
|
from skmisc.loess import loess
|
38
37
|
from statsmodels.stats.multitest import fdrcorrection
|
39
38
|
|
39
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
40
|
+
|
40
41
|
if TYPE_CHECKING:
|
41
42
|
from matplotlib.axes import Axes
|
42
43
|
from matplotlib.figure import Figure
|
@@ -439,7 +440,7 @@ class Augur:
|
|
439
440
|
"augur_score": make_scorer(self.ccc_score),
|
440
441
|
"r2": make_scorer(r2_score),
|
441
442
|
"ccc": make_scorer(self.ccc_score),
|
442
|
-
"neg_mean_squared_error": make_scorer(
|
443
|
+
"neg_mean_squared_error": make_scorer(root_mean_squared_error),
|
443
444
|
"explained_variance": make_scorer(explained_variance_score),
|
444
445
|
}
|
445
446
|
)
|
@@ -974,24 +975,26 @@ class Augur:
|
|
974
975
|
|
975
976
|
return delta
|
976
977
|
|
978
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
977
979
|
def plot_dp_scatter(
|
978
980
|
self,
|
979
981
|
results: pd.DataFrame,
|
982
|
+
*,
|
980
983
|
top_n: int = None,
|
981
|
-
return_fig: bool | None = None,
|
982
984
|
ax: Axes = None,
|
983
|
-
show: bool
|
984
|
-
|
985
|
-
) ->
|
985
|
+
show: bool = True,
|
986
|
+
return_fig: bool = False,
|
987
|
+
) -> Figure | None:
|
986
988
|
"""Plot scatterplot of differential prioritization.
|
987
989
|
|
988
990
|
Args:
|
989
991
|
results: Results after running differential prioritization.
|
990
992
|
top_n: optionally, the number of top prioritized cell types to label in the plot
|
991
993
|
ax: optionally, axes used to draw plot
|
994
|
+
{common_plot_args}
|
992
995
|
|
993
996
|
Returns:
|
994
|
-
|
997
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
995
998
|
|
996
999
|
Examples:
|
997
1000
|
>>> import pertpy as pt
|
@@ -1038,37 +1041,34 @@ class Augur:
|
|
1038
1041
|
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
|
1039
1042
|
ax.add_artist(legend1)
|
1040
1043
|
|
1041
|
-
if save:
|
1042
|
-
plt.savefig(save, bbox_inches="tight")
|
1043
1044
|
if show:
|
1044
1045
|
plt.show()
|
1045
1046
|
if return_fig:
|
1046
1047
|
return plt.gcf()
|
1047
|
-
if not (show or save):
|
1048
|
-
return ax
|
1049
1048
|
return None
|
1050
1049
|
|
1050
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1051
1051
|
def plot_important_features(
|
1052
1052
|
self,
|
1053
1053
|
data: dict[str, Any],
|
1054
|
+
*,
|
1054
1055
|
key: str = "augurpy_results",
|
1055
1056
|
top_n: int = 10,
|
1056
|
-
return_fig: bool | None = None,
|
1057
1057
|
ax: Axes = None,
|
1058
|
-
show: bool
|
1059
|
-
|
1060
|
-
) ->
|
1058
|
+
show: bool = True,
|
1059
|
+
return_fig: bool = False,
|
1060
|
+
) -> Figure | None:
|
1061
1061
|
"""Plot a lollipop plot of the n features with largest feature importances.
|
1062
1062
|
|
1063
1063
|
Args:
|
1064
|
-
|
1064
|
+
data: results after running `predict()` as dictionary or the AnnData object.
|
1065
1065
|
key: Key in the AnnData object of the results
|
1066
1066
|
top_n: n number feature importance values to plot. Default is 10.
|
1067
1067
|
ax: optionally, axes used to draw plot
|
1068
|
-
|
1068
|
+
{common_plot_args}
|
1069
1069
|
|
1070
1070
|
Returns:
|
1071
|
-
|
1071
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1072
1072
|
|
1073
1073
|
Examples:
|
1074
1074
|
>>> import pertpy as pt
|
@@ -1109,35 +1109,32 @@ class Augur:
|
|
1109
1109
|
plt.ylabel("Gene")
|
1110
1110
|
plt.yticks(y_axes_range, n_features["genes"])
|
1111
1111
|
|
1112
|
-
if save:
|
1113
|
-
plt.savefig(save, bbox_inches="tight")
|
1114
1112
|
if show:
|
1115
1113
|
plt.show()
|
1116
1114
|
if return_fig:
|
1117
1115
|
return plt.gcf()
|
1118
|
-
if not (show or save):
|
1119
|
-
return ax
|
1120
1116
|
return None
|
1121
1117
|
|
1118
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1122
1119
|
def plot_lollipop(
|
1123
1120
|
self,
|
1124
|
-
data: dict[str, Any],
|
1121
|
+
data: dict[str, Any] | AnnData,
|
1122
|
+
*,
|
1125
1123
|
key: str = "augurpy_results",
|
1126
|
-
return_fig: bool | None = None,
|
1127
1124
|
ax: Axes = None,
|
1128
|
-
show: bool
|
1129
|
-
|
1130
|
-
) ->
|
1125
|
+
show: bool = True,
|
1126
|
+
return_fig: bool = False,
|
1127
|
+
) -> Figure | None:
|
1131
1128
|
"""Plot a lollipop plot of the mean augur values.
|
1132
1129
|
|
1133
1130
|
Args:
|
1134
|
-
|
1135
|
-
key:
|
1136
|
-
ax: optionally, axes used to draw plot
|
1137
|
-
|
1131
|
+
data: results after running `predict()` as dictionary or the AnnData object.
|
1132
|
+
key: .uns key in the results AnnData object.
|
1133
|
+
ax: optionally, axes used to draw plot.
|
1134
|
+
{common_plot_args}
|
1138
1135
|
|
1139
1136
|
Returns:
|
1140
|
-
|
1137
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1141
1138
|
|
1142
1139
|
Examples:
|
1143
1140
|
>>> import pertpy as pt
|
@@ -1175,32 +1172,29 @@ class Augur:
|
|
1175
1172
|
plt.ylabel("Cell Type")
|
1176
1173
|
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
|
1177
1174
|
|
1178
|
-
if save:
|
1179
|
-
plt.savefig(save, bbox_inches="tight")
|
1180
1175
|
if show:
|
1181
1176
|
plt.show()
|
1182
1177
|
if return_fig:
|
1183
1178
|
return plt.gcf()
|
1184
|
-
if not (show or save):
|
1185
|
-
return ax
|
1186
1179
|
return None
|
1187
1180
|
|
1181
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1188
1182
|
def plot_scatterplot(
|
1189
1183
|
self,
|
1190
1184
|
results1: dict[str, Any],
|
1191
1185
|
results2: dict[str, Any],
|
1186
|
+
*,
|
1192
1187
|
top_n: int = None,
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
) -> Axes | Figure | None:
|
1188
|
+
show: bool = True,
|
1189
|
+
return_fig: bool = False,
|
1190
|
+
) -> Figure | None:
|
1197
1191
|
"""Create scatterplot with two augur results.
|
1198
1192
|
|
1199
1193
|
Args:
|
1200
1194
|
results1: results after running `predict()`
|
1201
1195
|
results2: results after running `predict()`
|
1202
1196
|
top_n: optionally, the number of top prioritized cell types to label in the plot
|
1203
|
-
|
1197
|
+
{common_plot_args}
|
1204
1198
|
|
1205
1199
|
Returns:
|
1206
1200
|
Axes of the plot.
|
@@ -1249,12 +1243,8 @@ class Augur:
|
|
1249
1243
|
plt.xlabel("Augur scores 1")
|
1250
1244
|
plt.ylabel("Augur scores 2")
|
1251
1245
|
|
1252
|
-
if save:
|
1253
|
-
plt.savefig(save, bbox_inches="tight")
|
1254
1246
|
if show:
|
1255
1247
|
plt.show()
|
1256
1248
|
if return_fig:
|
1257
1249
|
return plt.gcf()
|
1258
|
-
if not (show or save):
|
1259
|
-
return ax
|
1260
1250
|
return None
|
pertpy/tools/_cinemaot.py
CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
|
|
18
18
|
from sklearn.linear_model import LinearRegression
|
19
19
|
from sklearn.neighbors import NearestNeighbors
|
20
20
|
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
|
+
|
21
23
|
if TYPE_CHECKING:
|
22
24
|
from anndata import AnnData
|
23
25
|
from matplotlib.axes import Axes
|
26
|
+
from matplotlib.pyplot import Figure
|
24
27
|
from statsmodels.tools.typing import ArrayLike
|
25
28
|
|
26
29
|
|
@@ -88,7 +91,7 @@ class Cinemaot:
|
|
88
91
|
dim = self.get_dim(adata, use_rep=use_rep)
|
89
92
|
|
90
93
|
transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
|
91
|
-
X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
|
94
|
+
X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
|
92
95
|
groupvec = (adata.obs[pert_key] == control * 1).values # control
|
93
96
|
xi = np.zeros(dim)
|
94
97
|
j = 0
|
@@ -97,9 +100,9 @@ class Cinemaot:
|
|
97
100
|
xi[j] = xi_obj.correlation
|
98
101
|
j = j + 1
|
99
102
|
|
100
|
-
cf = X_transformed[:, xi < thres]
|
101
|
-
cf1 = cf[adata.obs[pert_key] == control, :]
|
102
|
-
cf2 = cf[adata.obs[pert_key] != control, :]
|
103
|
+
cf = np.array(X_transformed[:, xi < thres], np.float64)
|
104
|
+
cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
|
105
|
+
cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
|
103
106
|
if sum(xi < thres) == 1:
|
104
107
|
sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
|
105
108
|
elif sum(xi < thres) == 0:
|
@@ -167,7 +170,7 @@ class Cinemaot:
|
|
167
170
|
else:
|
168
171
|
_solver = sinkhorn.Sinkhorn(threshold=eps)
|
169
172
|
ot_sink = _solver(ot_prob)
|
170
|
-
ot_matrix = ot_sink.matrix.T
|
173
|
+
ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
|
171
174
|
embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
|
172
175
|
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
|
173
176
|
)
|
@@ -190,7 +193,7 @@ class Cinemaot:
|
|
190
193
|
TE.obsm["X_embedding"] = embedding
|
191
194
|
|
192
195
|
if return_matching:
|
193
|
-
TE.obsm["ot"] = ot_sink.matrix.T
|
196
|
+
TE.obsm["ot"] = np.asarray(ot_sink.matrix.T)
|
194
197
|
return TE
|
195
198
|
else:
|
196
199
|
return TE
|
@@ -639,6 +642,7 @@ class Cinemaot:
|
|
639
642
|
s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
|
640
643
|
return c_effect, s_effect
|
641
644
|
|
645
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
642
646
|
def plot_vis_matching(
|
643
647
|
self,
|
644
648
|
adata: AnnData,
|
@@ -647,16 +651,17 @@ class Cinemaot:
|
|
647
651
|
control: str,
|
648
652
|
de_label: str,
|
649
653
|
source_label: str,
|
654
|
+
*,
|
650
655
|
matching_rep: str = "ot",
|
651
656
|
resolution: float = 0.5,
|
652
657
|
normalize: str = "col",
|
653
658
|
title: str = "CINEMA-OT matching matrix",
|
654
659
|
min_val: float = 0.01,
|
655
|
-
show: bool = True,
|
656
|
-
save: str | None = None,
|
657
660
|
ax: Axes | None = None,
|
661
|
+
show: bool = True,
|
662
|
+
return_fig: bool = False,
|
658
663
|
**kwargs,
|
659
|
-
) -> None:
|
664
|
+
) -> Figure | None:
|
660
665
|
"""Visualize the CINEMA-OT matching matrix.
|
661
666
|
|
662
667
|
Args:
|
@@ -670,11 +675,12 @@ class Cinemaot:
|
|
670
675
|
normalize: normalize the coarse-grained matching matrix by row / column.
|
671
676
|
title: the title for the figure.
|
672
677
|
min_val: The min value to truncate the matching matrix.
|
673
|
-
|
674
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
675
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
678
|
+
{common_plot_args}
|
676
679
|
**kwargs: Other parameters to input for seaborn.heatmap.
|
677
680
|
|
681
|
+
Returns:
|
682
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
683
|
+
|
678
684
|
Examples:
|
679
685
|
>>> import pertpy as pt
|
680
686
|
>>> adata = pt.dt.cinemaot_example()
|
@@ -710,12 +716,12 @@ class Cinemaot:
|
|
710
716
|
|
711
717
|
g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
|
712
718
|
plt.title(title)
|
713
|
-
|
714
|
-
if
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
+
|
720
|
+
if show:
|
721
|
+
plt.show()
|
722
|
+
if return_fig:
|
723
|
+
return g
|
724
|
+
return None
|
719
725
|
|
720
726
|
|
721
727
|
class Xi:
|