pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/metadata/_moa.py
ADDED
@@ -0,0 +1,129 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
from rich import print
|
9
|
+
from scanpy import settings
|
10
|
+
|
11
|
+
from pertpy.data._dataloader import _download
|
12
|
+
|
13
|
+
from ._look_up import LookUp
|
14
|
+
from ._metadata import MetaData
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from anndata import AnnData
|
18
|
+
|
19
|
+
|
20
|
+
class Moa(MetaData):
|
21
|
+
"""Utilities to fetch metadata for mechanism of action studies."""
|
22
|
+
|
23
|
+
def __init__(self):
|
24
|
+
self.clue = None
|
25
|
+
|
26
|
+
def _download_clue(self) -> None:
|
27
|
+
clue_path = Path(settings.cachedir) / "repurposing_drugs_20200324.txt"
|
28
|
+
if not Path(clue_path).exists():
|
29
|
+
print("[bold yellow]No metadata file was found for clue. Starting download now.")
|
30
|
+
_download(
|
31
|
+
url="https://s3.amazonaws.com/data.clue.io/repurposing/downloads/repurposing_drugs_20200324.txt",
|
32
|
+
output_file_name="repurposing_drugs_20200324.txt",
|
33
|
+
output_path=settings.cachedir,
|
34
|
+
block_size=4096,
|
35
|
+
is_zip=False,
|
36
|
+
)
|
37
|
+
self.clue = pd.read_csv(clue_path, sep=" ", skiprows=9)
|
38
|
+
self.clue = self.clue[["pert_iname", "moa", "target"]]
|
39
|
+
|
40
|
+
def annotate(
|
41
|
+
self,
|
42
|
+
adata: AnnData,
|
43
|
+
query_id: str = "perturbation",
|
44
|
+
target: str | None = None,
|
45
|
+
verbosity: int | str = 5,
|
46
|
+
copy: bool = False,
|
47
|
+
) -> AnnData:
|
48
|
+
"""Annotate cells affected by perturbations by mechanism of action.
|
49
|
+
|
50
|
+
For each cell, we fetch the mechanism of action and molecular targets of the compounds sourced from clue.io.
|
51
|
+
|
52
|
+
Args:
|
53
|
+
adata: The data object to annotate.
|
54
|
+
query_id: The column of `.obs` with the name of a perturbagen. Defaults to 'perturbation'.
|
55
|
+
target: The column of `.obs` with target information. If set to None, all MoAs are retrieved without comparing molecular targets.
|
56
|
+
Defaults to None.
|
57
|
+
verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'.
|
58
|
+
Defaults to 5.
|
59
|
+
copy: Determines whether a copy of the `adata` is returned. Defaults to False.
|
60
|
+
|
61
|
+
Returns:
|
62
|
+
Returns an AnnData object with MoA annotation.
|
63
|
+
"""
|
64
|
+
if copy:
|
65
|
+
adata = adata.copy()
|
66
|
+
|
67
|
+
if query_id not in adata.obs.columns:
|
68
|
+
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n" "Please check again.")
|
69
|
+
|
70
|
+
if self.clue is None:
|
71
|
+
self._download_clue()
|
72
|
+
|
73
|
+
identifier_num_all = len(adata.obs[query_id].unique())
|
74
|
+
not_matched_identifiers = list(set(adata.obs[query_id].str.lower()) - set(self.clue["pert_iname"].str.lower()))
|
75
|
+
self._warn_unmatch(
|
76
|
+
total_identifiers=identifier_num_all,
|
77
|
+
unmatched_identifiers=not_matched_identifiers,
|
78
|
+
query_id=query_id,
|
79
|
+
reference_id="pert_iname",
|
80
|
+
metadata_type="moa",
|
81
|
+
verbosity=verbosity,
|
82
|
+
)
|
83
|
+
|
84
|
+
adata.obs = (
|
85
|
+
adata.obs.merge(
|
86
|
+
self.clue,
|
87
|
+
left_on=adata.obs[query_id].str.lower(),
|
88
|
+
right_on=self.clue["pert_iname"].str.lower(),
|
89
|
+
how="left",
|
90
|
+
suffixes=("", "_fromMeta"),
|
91
|
+
)
|
92
|
+
.set_index(adata.obs.index)
|
93
|
+
.drop("key_0", axis=1)
|
94
|
+
)
|
95
|
+
|
96
|
+
# If target column is given, check whether it is one of the targets listed in the metadata
|
97
|
+
# If inconsistent, treat this perturbagen as unmatched and overwrite the annotated metadata with NaN
|
98
|
+
if target is not None:
|
99
|
+
target_meta = "target" if target != "target" else "target_fromMeta"
|
100
|
+
adata.obs[target_meta] = adata.obs[target_meta].mask(
|
101
|
+
~adata.obs.apply(lambda row: str(row[target]) in str(row[target_meta]), axis=1)
|
102
|
+
)
|
103
|
+
pertname_meta = "pert_iname" if query_id != "pert_iname" else "pert_iname_fromMeta"
|
104
|
+
adata.obs.loc[adata.obs[target_meta].isna(), [pertname_meta, "moa"]] = np.nan
|
105
|
+
|
106
|
+
# If query_id and reference_id have different names, there will be a column for each of them after merging
|
107
|
+
# which is redundant as they refer to the same information.
|
108
|
+
if query_id != "pert_iname":
|
109
|
+
del adata.obs["pert_iname"]
|
110
|
+
|
111
|
+
return adata
|
112
|
+
|
113
|
+
def lookup(self) -> LookUp:
|
114
|
+
"""Generate LookUp object for Moa metadata.
|
115
|
+
|
116
|
+
The LookUp object provides an overview of the metadata to annotate.
|
117
|
+
annotate_moa function has a corresponding lookup function in the LookUp object,
|
118
|
+
where users can search the query_ids and targets in the metadata.
|
119
|
+
|
120
|
+
Returns:
|
121
|
+
Returns a LookUp object specific for MoA annotation.
|
122
|
+
"""
|
123
|
+
if self.clue is None:
|
124
|
+
self._download_clue()
|
125
|
+
|
126
|
+
return LookUp(
|
127
|
+
type="moa",
|
128
|
+
transfer_metadata=[self.clue],
|
129
|
+
)
|
pertpy/plot/__init__.py
CHANGED
@@ -1,13 +1,5 @@
|
|
1
1
|
from pertpy.plot._augur import AugurpyPlot as ag
|
2
|
-
from pertpy.plot.
|
3
|
-
|
4
|
-
try:
|
5
|
-
from pertpy.plot._coda import CodaPlot as coda
|
6
|
-
except ImportError:
|
7
|
-
pass
|
8
|
-
|
9
|
-
from pertpy.plot._cinemaot import CinemaotPlot as cot
|
2
|
+
from pertpy.plot._coda import CodaPlot as coda
|
10
3
|
from pertpy.plot._guide_rna import GuideRnaPlot as guide
|
11
4
|
from pertpy.plot._milopy import MilopyPlot as milo
|
12
5
|
from pertpy.plot._mixscape import MixscapePlot as ms
|
13
|
-
from pertpy.plot._scgen import JaxscgenPlot as scg
|
pertpy/plot/_augur.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import warnings
|
3
4
|
from typing import TYPE_CHECKING, Any
|
4
5
|
|
5
6
|
from anndata import AnnData
|
@@ -15,7 +16,7 @@ class AugurpyPlot:
|
|
15
16
|
"""Plotting functions for Augurpy."""
|
16
17
|
|
17
18
|
@staticmethod
|
18
|
-
def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None
|
19
|
+
def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None) -> Figure | Axes:
|
19
20
|
"""Plot result of differential prioritization.
|
20
21
|
|
21
22
|
Args:
|
@@ -42,38 +43,24 @@ class AugurpyPlot:
|
|
42
43
|
|
43
44
|
>>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
|
44
45
|
permuted_results1=results_15_permute, permuted_results2=results_48_permute)
|
45
|
-
>>>
|
46
|
+
>>> ag_rfc.plot_dp_scatter(pvals)
|
46
47
|
"""
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
# adding optional labels
|
55
|
-
top_n_index = results.sort_values(by="pval").index[:top_n]
|
56
|
-
for idx in top_n_index:
|
57
|
-
ax.annotate(
|
58
|
-
results.loc[idx, "cell_type"],
|
59
|
-
(results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
|
60
|
-
)
|
48
|
+
warnings.warn(
|
49
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
50
|
+
" Please use the corresponding 'pt.tl' object",
|
51
|
+
FutureWarning,
|
52
|
+
stacklevel=2,
|
53
|
+
)
|
61
54
|
|
62
|
-
|
63
|
-
limits = max(ax.get_xlim(), ax.get_ylim())
|
64
|
-
(diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
|
55
|
+
from pertpy.tools import Augur
|
65
56
|
|
66
|
-
|
67
|
-
plt.xlabel("Augur scores 1")
|
68
|
-
plt.ylabel("Augur scores 2")
|
69
|
-
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
|
70
|
-
ax.add_artist(legend1)
|
57
|
+
ag = Augur("random_forest_classifier")
|
71
58
|
|
72
|
-
return
|
59
|
+
return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax)
|
73
60
|
|
74
61
|
@staticmethod
|
75
62
|
def important_features(
|
76
|
-
data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None
|
63
|
+
data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None
|
77
64
|
) -> Figure | Axes:
|
78
65
|
"""Plot a lollipop plot of the n features with largest feature importances.
|
79
66
|
|
@@ -92,44 +79,26 @@ class AugurpyPlot:
|
|
92
79
|
>>> adata = pt.dt.sc_sim_augur()
|
93
80
|
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
|
94
81
|
>>> loaded_data = ag_rfc.load(adata)
|
95
|
-
>>> v_adata, v_results = ag_rfc.predict(
|
96
|
-
|
82
|
+
>>> v_adata, v_results = ag_rfc.predict(
|
83
|
+
... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
|
84
|
+
... )
|
85
|
+
>>> ag_rfc.plot_important_features(v_results)
|
97
86
|
"""
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
n_features = (
|
104
|
-
results["feature_importances"]
|
105
|
-
.groupby("genes", as_index=False)
|
106
|
-
.feature_importances.mean()
|
107
|
-
.sort_values(by="feature_importances")[-top_n:]
|
87
|
+
warnings.warn(
|
88
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
89
|
+
" Please use the corresponding 'pt.tl' object",
|
90
|
+
FutureWarning,
|
91
|
+
stacklevel=2,
|
108
92
|
)
|
109
93
|
|
110
|
-
|
111
|
-
fig, ax = plt.subplots()
|
112
|
-
y_axes_range = range(1, top_n + 1)
|
113
|
-
ax.hlines(
|
114
|
-
y_axes_range,
|
115
|
-
xmin=0,
|
116
|
-
xmax=n_features["feature_importances"],
|
117
|
-
)
|
118
|
-
|
119
|
-
# drawing the markers (circle)
|
120
|
-
ax.plot(n_features["feature_importances"], y_axes_range, "o")
|
94
|
+
from pertpy.tools import Augur
|
121
95
|
|
122
|
-
|
123
|
-
plt.xlabel("Mean Feature Importance")
|
124
|
-
plt.ylabel("Gene")
|
125
|
-
plt.yticks(y_axes_range, n_features["genes"])
|
96
|
+
ag = Augur("random_forest_classifier")
|
126
97
|
|
127
|
-
return
|
98
|
+
return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax)
|
128
99
|
|
129
100
|
@staticmethod
|
130
|
-
def lollipop(
|
131
|
-
data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False
|
132
|
-
) -> Figure | Axes:
|
101
|
+
def lollipop(data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None) -> Figure | Axes | None:
|
133
102
|
"""Plot a lollipop plot of the mean augur values.
|
134
103
|
|
135
104
|
Args:
|
@@ -146,40 +115,26 @@ class AugurpyPlot:
|
|
146
115
|
>>> adata = pt.dt.sc_sim_augur()
|
147
116
|
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
|
148
117
|
>>> loaded_data = ag_rfc.load(adata)
|
149
|
-
>>> v_adata, v_results = ag_rfc.predict(
|
150
|
-
|
118
|
+
>>> v_adata, v_results = ag_rfc.predict(
|
119
|
+
... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
|
120
|
+
... )
|
121
|
+
>>> ag_rfc.plot_lollipop(v_results)
|
151
122
|
"""
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
fig, ax = plt.subplots()
|
158
|
-
y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
|
159
|
-
ax.hlines(
|
160
|
-
y_axes_range,
|
161
|
-
xmin=0,
|
162
|
-
xmax=results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
|
123
|
+
warnings.warn(
|
124
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
125
|
+
" Please use the corresponding 'pt.tl' object",
|
126
|
+
FutureWarning,
|
127
|
+
stacklevel=2,
|
163
128
|
)
|
164
129
|
|
165
|
-
|
166
|
-
ax.plot(
|
167
|
-
results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
|
168
|
-
y_axes_range,
|
169
|
-
"o",
|
170
|
-
)
|
130
|
+
from pertpy.tools import Augur
|
171
131
|
|
172
|
-
|
173
|
-
plt.xlabel("Mean Augur Score")
|
174
|
-
plt.ylabel("Cell Type")
|
175
|
-
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
|
132
|
+
ag = Augur("random_forest_classifier")
|
176
133
|
|
177
|
-
return
|
134
|
+
return ag.plot_lollipop(data=data, key=key, ax=ax)
|
178
135
|
|
179
136
|
@staticmethod
|
180
|
-
def scatterplot(
|
181
|
-
results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False
|
182
|
-
) -> Figure | Axes:
|
137
|
+
def scatterplot(results1: dict[str, Any], results2: dict[str, Any], top_n=None) -> Figure | Axes:
|
183
138
|
"""Create scatterplot with two augur results.
|
184
139
|
|
185
140
|
Args:
|
@@ -197,38 +152,20 @@ class AugurpyPlot:
|
|
197
152
|
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
|
198
153
|
>>> loaded_data = ag_rfc.load(adata)
|
199
154
|
>>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
|
200
|
-
>>> v_adata, v_results = ag_rfc.predict(
|
201
|
-
|
155
|
+
>>> v_adata, v_results = ag_rfc.predict(
|
156
|
+
... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
|
157
|
+
... )
|
158
|
+
>>> ag_rfc.plot_scatterplot(v_results, h_results)
|
202
159
|
"""
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
results2["summary_metrics"].loc["mean_augur_score", cell_types],
|
160
|
+
warnings.warn(
|
161
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
162
|
+
" Please use the corresponding 'pt.tl' object",
|
163
|
+
FutureWarning,
|
164
|
+
stacklevel=2,
|
209
165
|
)
|
210
166
|
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
)
|
217
|
-
for txt in top_n_cell_types:
|
218
|
-
ax.annotate(
|
219
|
-
txt,
|
220
|
-
(
|
221
|
-
results1["summary_metrics"].loc["mean_augur_score", txt],
|
222
|
-
results2["summary_metrics"].loc["mean_augur_score", txt],
|
223
|
-
),
|
224
|
-
)
|
225
|
-
|
226
|
-
# adding diagonal
|
227
|
-
limits = max(ax.get_xlim(), ax.get_ylim())
|
228
|
-
(diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
|
229
|
-
|
230
|
-
# formatting and details
|
231
|
-
plt.xlabel("Augur scores 1")
|
232
|
-
plt.ylabel("Augur scores 2")
|
233
|
-
|
234
|
-
return fig if return_figure else ax
|
167
|
+
from pertpy.tools import Augur
|
168
|
+
|
169
|
+
ag = Augur("random_forest_classifier")
|
170
|
+
|
171
|
+
return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n)
|