pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import uuid
|
3
4
|
from typing import TYPE_CHECKING
|
4
5
|
|
5
6
|
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
import scanpy as sc
|
6
9
|
import scipy
|
7
10
|
|
8
11
|
if TYPE_CHECKING:
|
9
12
|
from anndata import AnnData
|
13
|
+
from matplotlib.axes import Axes
|
10
14
|
|
11
15
|
|
12
16
|
class GuideAssignment:
|
@@ -39,7 +43,7 @@ class GuideAssignment:
|
|
39
43
|
|
40
44
|
>>> import pertpy as pt
|
41
45
|
>>> mdata = pt.data.papalexi_2021()
|
42
|
-
>>> gdo = mdata.mod[
|
46
|
+
>>> gdo = mdata.mod["gdo"]
|
43
47
|
>>> ga = pt.pp.GuideAssignment()
|
44
48
|
>>> ga.assign_by_threshold(gdo, assignment_threshold=5)
|
45
49
|
"""
|
@@ -71,7 +75,6 @@ class GuideAssignment:
|
|
71
75
|
|
72
76
|
Args:
|
73
77
|
adata: Annotated data matrix containing gRNA values
|
74
|
-
assignment_threshold: If a gRNA is available for at least `assignment_threshold`, it will be recognized as assigned.
|
75
78
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
76
79
|
layer: Key to the layer containing raw count values of the gRNAs.
|
77
80
|
adata.X is used if layer is None. Expects count data.
|
@@ -83,8 +86,8 @@ class GuideAssignment:
|
|
83
86
|
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
|
84
87
|
|
85
88
|
>>> import pertpy as pt
|
86
|
-
>>> mdata = pt.
|
87
|
-
>>> gdo = mdata.mod[
|
89
|
+
>>> mdata = pt.dt.papalexi_2021()
|
90
|
+
>>> gdo = mdata.mod["gdo"]
|
88
91
|
>>> ga = pt.pp.GuideAssignment()
|
89
92
|
>>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
|
90
93
|
"""
|
@@ -103,3 +106,84 @@ class GuideAssignment:
|
|
103
106
|
adata.obs[output_key] = assigned_grna
|
104
107
|
|
105
108
|
return None
|
109
|
+
|
110
|
+
def plot_heatmap(
    self,
    adata: AnnData,
    layer: str | None = None,
    order_by: np.ndarray | str | None = None,
    key_to_save_order: str | None = None,
    **kwargs,
) -> list[Axes]:
    """Heatmap plotting of guide RNA expression matrix.

    Assuming guides have sparse expression, this function reorders cells
    and plots guide RNA expression so that a nice sparse representation is achieved.
    The cell ordering can be stored and reused in future plots to obtain consistent
    plots before and after analysis of the guide RNA expression.
    Note: This function expects log-normalized or binary data.

    Args:
        adata: Annotated data matrix containing gRNA values
        layer: Key to the layer containing log normalized count values of the gRNAs.
               adata.X is used if layer is None.
        order_by: The order of cells in y axis. Defaults to None.
                  If None, cells are sorted by the index of their maximal guide; cells whose
                  row maximum equals their row minimum are given index -1 and sort first.
                  If a string is provided, adata.obs[order_by] will be used as the order.
                  If a numpy array is provided, the array will be used for ordering.
        key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
        kwargs: Are passed to sc.pl.heatmap. `cmap` defaults to "viridis" but may be overridden here.

    Returns:
        Whatever sc.pl.heatmap returns (annotated as a list of Axes). Alternatively you can pass
        save or show parameters as they will be passed to sc.pl.heatmap.
        Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.

    Examples:
        Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
        visualized using a heatmap.

        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
        >>> ga.plot_heatmap(gdo)
    """
    data = adata.X if layer is None else adata.layers[layer]

    if order_by is None:
        if scipy.sparse.issparse(data):
            # Sparse reductions return sparse containers; `toarray()` works for both the
            # matrix and the array flavour, whereas the `.A` shorthand does not.
            max_values = data.max(axis=1).toarray().ravel()
            min_values = data.min(axis=1).toarray().ravel()
        else:
            max_values = np.asarray(data.max(axis=1)).ravel()
            min_values = np.asarray(data.min(axis=1)).ravel()
        # `ravel` (unlike `squeeze`) keeps the result 1-D even for a single cell.
        data_argmax = np.asarray(data.argmax(axis=1)).ravel()
        # A row whose max equals its min has no distinguishable guide -> sentinel -1.
        max_guide_index = np.where(max_values != min_values, data_argmax, -1)
        order = np.argsort(max_guide_index)
    elif isinstance(order_by, str):
        order = np.argsort(adata.obs[order_by])
    else:
        order = order_by

    # sc.pl.heatmap requires a `groupby` column; a throw-away single-category column with a
    # unique name is used so that all cells land in one group and no user column is clobbered.
    temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
    adata.obs[temp_col_name] = pd.Categorical(["" for _ in range(adata.shape[0])])

    # Let callers override the colormap instead of failing with a duplicate keyword.
    kwargs.setdefault("cmap", "viridis")

    try:
        if key_to_save_order is not None:
            adata.obs[key_to_save_order] = pd.Categorical(order)

        axis_group = sc.pl.heatmap(
            adata[order, :],
            var_names=adata.var.index.tolist(),
            groupby=temp_col_name,
            use_raw=False,
            dendrogram=False,
            layer=layer,
            **kwargs,
        )
    finally:
        # Always remove the helper column, even if saving the order or plotting raised.
        del adata.obs[temp_col_name]

    return axis_group
|
pertpy/tools/__init__.py
CHANGED
@@ -1,24 +1,19 @@
|
|
1
|
-
from rich import print
|
2
|
-
|
3
1
|
from pertpy.tools._augur import Augur
|
4
2
|
from pertpy.tools._cinemaot import Cinemaot
|
3
|
+
from pertpy.tools._coda._sccoda import Sccoda
|
4
|
+
from pertpy.tools._coda._tasccoda import Tasccoda
|
5
5
|
from pertpy.tools._dialogue import Dialogue
|
6
6
|
from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
|
7
7
|
from pertpy.tools._distances._distance_tests import DistanceTest
|
8
8
|
from pertpy.tools._distances._distances import Distance
|
9
|
-
from pertpy.tools.
|
9
|
+
from pertpy.tools._enrichment import Enrichment
|
10
10
|
from pertpy.tools._milo import Milo
|
11
11
|
from pertpy.tools._mixscape import Mixscape
|
12
12
|
from pertpy.tools._perturbation_space._clustering import ClusteringSpace
|
13
|
-
from pertpy.tools._perturbation_space.
|
13
|
+
from pertpy.tools._perturbation_space._discriminator_classifiers import (
|
14
|
+
DiscriminatorClassifierSpace,
|
15
|
+
LRClassifierSpace,
|
16
|
+
MLPClassifierSpace,
|
17
|
+
)
|
14
18
|
from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
|
15
19
|
from pertpy.tools._scgen import SCGEN
|
16
|
-
|
17
|
-
try:
|
18
|
-
from pertpy.tools._coda._sccoda import Sccoda
|
19
|
-
from pertpy.tools._coda._tasccoda import Tasccoda
|
20
|
-
except ImportError as e:
|
21
|
-
if "ete3" in str(e):
|
22
|
-
print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3")
|
23
|
-
else:
|
24
|
-
raise e
|
pertpy/tools/_augur.py
CHANGED
@@ -4,8 +4,10 @@ import random
|
|
4
4
|
from collections import defaultdict
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from math import floor, nan
|
7
|
-
from typing import Any, Literal
|
7
|
+
from typing import TYPE_CHECKING, Any, Literal
|
8
8
|
|
9
|
+
import anndata as ad
|
10
|
+
import matplotlib.pyplot as plt
|
9
11
|
import numpy as np
|
10
12
|
import pandas as pd
|
11
13
|
import scanpy as sc
|
@@ -34,6 +36,10 @@ from sklearn.preprocessing import LabelEncoder
|
|
34
36
|
from skmisc.loess import loess
|
35
37
|
from statsmodels.stats.multitest import fdrcorrection
|
36
38
|
|
39
|
+
if TYPE_CHECKING:
|
40
|
+
from matplotlib.axes import Axes
|
41
|
+
from matplotlib.figure import Figure
|
42
|
+
|
37
43
|
|
38
44
|
@dataclass
|
39
45
|
class Params:
|
@@ -135,8 +141,8 @@ class Augur:
|
|
135
141
|
# filter samples according to label
|
136
142
|
if condition_label is not None and treatment_label is not None:
|
137
143
|
print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
|
138
|
-
adata =
|
139
|
-
adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
|
144
|
+
adata = ad.concat(
|
145
|
+
[adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
|
140
146
|
)
|
141
147
|
label_encoder = LabelEncoder()
|
142
148
|
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
|
@@ -214,7 +220,9 @@ class Augur:
|
|
214
220
|
>>> loaded_data = ag_rfc.load(adata)
|
215
221
|
>>> ag_rfc.select_highly_variable(loaded_data)
|
216
222
|
>>> features = loaded_data.var_names
|
217
|
-
>>> subsample = ag_rfc.sample(
|
223
|
+
>>> subsample = ag_rfc.sample(
|
224
|
+
... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
|
225
|
+
... )
|
218
226
|
"""
|
219
227
|
# export subsampling.
|
220
228
|
random.seed(random_state)
|
@@ -230,7 +238,7 @@ class Augur:
|
|
230
238
|
random_state=random_state,
|
231
239
|
)
|
232
240
|
)
|
233
|
-
subsample =
|
241
|
+
subsample = ad.concat([*label_subsamples], index_unique=None)
|
234
242
|
else:
|
235
243
|
subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
|
236
244
|
|
@@ -409,8 +417,8 @@ class Augur:
|
|
409
417
|
"""
|
410
418
|
if multiclass:
|
411
419
|
return {
|
412
|
-
"augur_score": make_scorer(roc_auc_score, multi_class="ovo",
|
413
|
-
"auc": make_scorer(roc_auc_score, multi_class="ovo",
|
420
|
+
"augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
|
421
|
+
"auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
|
414
422
|
"accuracy": make_scorer(accuracy_score),
|
415
423
|
"precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
|
416
424
|
"f1": make_scorer(f1_score, average="macro"),
|
@@ -418,8 +426,8 @@ class Augur:
|
|
418
426
|
}
|
419
427
|
return (
|
420
428
|
{
|
421
|
-
"augur_score": make_scorer(roc_auc_score,
|
422
|
-
"auc": make_scorer(roc_auc_score,
|
429
|
+
"augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
|
430
|
+
"auc": make_scorer(roc_auc_score, response_method="predict_proba"),
|
423
431
|
"accuracy": make_scorer(accuracy_score),
|
424
432
|
"precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
|
425
433
|
"f1": make_scorer(f1_score, average="binary"),
|
@@ -488,7 +496,7 @@ class Augur:
|
|
488
496
|
# feature importances
|
489
497
|
feature_importances = defaultdict(list)
|
490
498
|
if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
|
491
|
-
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
|
499
|
+
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
|
492
500
|
feature_importances["genes"].extend(x.columns.tolist())
|
493
501
|
feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
|
494
502
|
feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
|
@@ -497,7 +505,7 @@ class Augur:
|
|
497
505
|
# standardized coefficients with Agresti method
|
498
506
|
# cf. https://think-lab.github.io/d/205/#3
|
499
507
|
if isinstance(self.estimator, LogisticRegression):
|
500
|
-
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
|
508
|
+
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
|
501
509
|
feature_importances["genes"].extend(x.columns.tolist())
|
502
510
|
feature_importances["feature_importances"].extend(
|
503
511
|
(self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
|
@@ -723,6 +731,7 @@ class Augur:
|
|
723
731
|
>>> loaded_data = ag_rfc.load(adata)
|
724
732
|
>>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
|
725
733
|
"""
|
734
|
+
adata = adata.copy()
|
726
735
|
if augur_mode == "permute" and n_subsamples < 100:
|
727
736
|
n_subsamples = 500
|
728
737
|
if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
|
@@ -765,6 +774,7 @@ class Augur:
|
|
765
774
|
elif (
|
766
775
|
cell_type_subsample.obs.groupby(
|
767
776
|
["cell_type", "label"],
|
777
|
+
observed=True,
|
768
778
|
).y_.count()
|
769
779
|
< subsample_size
|
770
780
|
).any():
|
@@ -804,7 +814,7 @@ class Augur:
|
|
804
814
|
* (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
|
805
815
|
)
|
806
816
|
|
807
|
-
for idx, cv in zip(range(n_subsamples), results[cell_type]):
|
817
|
+
for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
|
808
818
|
results["full_results"]["idx"].extend([idx] * folds)
|
809
819
|
results["full_results"]["augur_score"].extend(cv["test_augur_score"])
|
810
820
|
results["full_results"]["folds"].extend(range(folds))
|
@@ -869,28 +879,31 @@ class Augur:
|
|
869
879
|
& set(permuted_results1["summary_metrics"].columns)
|
870
880
|
& set(permuted_results2["summary_metrics"].columns)
|
871
881
|
)
|
882
|
+
|
883
|
+
cell_types_list = list(cell_types)
|
884
|
+
|
872
885
|
# mean augur scores
|
873
886
|
augur_score1 = (
|
874
887
|
augur_results1["summary_metrics"]
|
875
|
-
.loc["mean_augur_score",
|
888
|
+
.loc["mean_augur_score", cell_types_list]
|
876
889
|
.reset_index()
|
877
890
|
.rename(columns={"index": "cell_type"})
|
878
891
|
)
|
879
892
|
augur_score2 = (
|
880
893
|
augur_results2["summary_metrics"]
|
881
|
-
.loc["mean_augur_score",
|
894
|
+
.loc["mean_augur_score", cell_types_list]
|
882
895
|
.reset_index()
|
883
896
|
.rename(columns={"index": "cell_type"})
|
884
897
|
)
|
885
898
|
|
886
899
|
# mean permuted scores over cross validation runs
|
887
900
|
permuted_cv_augur1 = (
|
888
|
-
permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(
|
901
|
+
permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
|
889
902
|
.groupby(["cell_type", "idx"], as_index=False)
|
890
903
|
.mean()
|
891
904
|
)
|
892
905
|
permuted_cv_augur2 = (
|
893
|
-
permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(
|
906
|
+
permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
|
894
907
|
.groupby(["cell_type", "idx"], as_index=False)
|
895
908
|
.mean()
|
896
909
|
)
|
@@ -901,7 +914,7 @@ class Augur:
|
|
901
914
|
# draw mean aucs for permute1 and permute2
|
902
915
|
for celltype in permuted_cv_augur1["cell_type"].unique():
|
903
916
|
df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
|
904
|
-
df2 = permuted_cv_augur2[
|
917
|
+
df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
|
905
918
|
for permutation_idx in range(n_permutations):
|
906
919
|
# subsample
|
907
920
|
sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
|
@@ -961,3 +974,288 @@ class Augur:
|
|
961
974
|
delta["padj"] = fdrcorrection(delta["pval"])[1]
|
962
975
|
|
963
976
|
return delta
|
977
|
+
|
978
|
+
def plot_dp_scatter(
    self,
    results: pd.DataFrame,
    top_n: int | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot scatterplot of differential prioritization.

    Args:
        results: Results after running differential prioritization. Must provide the columns
                 `mean_augur_score1`, `mean_augur_score2`, `z`, `pval` and `cell_type`.
        top_n: optionally, the number of top prioritized cell types (smallest `pval`) to label in the plot.
               If None, every cell type is labelled.
        return_fig: if truthy, the figure is returned.
        ax: optionally, axes used to draw plot. A new figure is created if None.
        show: if truthy, `plt.show()` is called.
        save: if truthy, the figure is saved with this value passed on to `savefig`.

    Returns:
        The figure if `return_fig` is set, otherwise the Axes of the plot if neither `show` nor
        `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.bhattacherjee()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
        >>> adata_15_permute, results_15_permute = ag_rfc.predict(data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)

        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
        >>> adata_48_permute, results_48_permute = ag_rfc.predict(data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)

        >>> pvals = ag_rfc.predict_differential_prioritization(
        ...     augur_results1=results_15,
        ...     augur_results2=results_48,
        ...     permuted_results1=results_15_permute,
        ...     permuted_results2=results_48_permute,
        ... )
        >>> ag_rfc.plot_dp_scatter(pvals)

    Preview:
        .. image:: /_static/docstring_previews/augur_dp_scatter.png
    """
    x = results["mean_augur_score1"]
    y = results["mean_augur_score2"]

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`, which need not be pyplot's current figure.
    fig = ax.get_figure()
    scatter = ax.scatter(x, y, c=results.z, cmap="Greens")

    # adding optional labels
    top_n_index = results.sort_values(by="pval").index[:top_n]
    for idx in top_n_index:
        ax.annotate(
            results.loc[idx, "cell_type"],
            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
        )

    # add diagonal spanning the union of both axis ranges
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    # formatting and details
    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")
    legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
    ax.add_artist(legend1)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1051
|
+
|
1052
|
+
def plot_important_features(
    self,
    data: dict[str, Any] | ad.AnnData,
    key: str = "augurpy_results",
    top_n: int = 10,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the n features with largest feature importances.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `.uns` of the AnnData object under which the results are stored.
             Only used if `data` is an AnnData object.
        top_n: n number feature importance values to plot. Default is 10.
               If fewer genes are available, all of them are plotted.
        return_fig: if truthy, the figure is returned.
        ax: optionally, axes used to draw plot. A new figure is created if None.
        show: if truthy, `plt.show()` is called.
        save: if truthy, the figure is saved with this value passed on to `savefig`.

    Returns:
        The figure if `return_fig` is set, otherwise the Axes of the plot if neither `show` nor
        `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_important_features(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_important_features.png
    """
    results = data.uns[key] if isinstance(data, ad.AnnData) else data

    # Mean importance per gene, ascending, keeping the `top_n` largest at the end.
    n_features = (
        results["feature_importances"]
        .groupby("genes", as_index=False)
        .feature_importances.mean()
        .sort_values(by="feature_importances")[-top_n:]
    )

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`, which need not be pyplot's current figure.
    fig = ax.get_figure()

    # Size the y positions by the rows actually selected: there may be fewer than `top_n` genes.
    y_axes_range = range(1, len(n_features) + 1)
    ax.hlines(
        y_axes_range,
        xmin=0,
        xmax=n_features["feature_importances"],
    )

    ax.plot(n_features["feature_importances"], y_axes_range, "o")

    ax.set_xlabel("Mean Feature Importance")
    ax.set_ylabel("Gene")
    ax.set_yticks(list(y_axes_range))
    ax.set_yticklabels(n_features["genes"])

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1122
|
+
|
1123
|
+
def plot_lollipop(
    self,
    data: dict[str, Any] | ad.AnnData,
    key: str = "augurpy_results",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the mean augur values.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `.uns` of the AnnData object under which the results are stored.
             Only used if `data` is an AnnData object.
        return_fig: if truthy, the figure is returned.
        ax: optionally, axes used to draw plot. A new figure is created if None.
        show: if truthy, `plt.show()` is called.
        save: if truthy, the figure is saved with this value passed on to `savefig`.

    Returns:
        The figure if `return_fig` is set, otherwise the Axes of the plot if neither `show` nor
        `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_lollipop(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_lollipop.png
    """
    results = data.uns[key] if isinstance(data, ad.AnnData) else data

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`, which need not be pyplot's current figure.
    fig = ax.get_figure()

    # Order the cell-type columns by their mean augur score once and reuse the result
    # for the stems, the markers and the tick labels.
    sorted_metrics = results["summary_metrics"].sort_values("mean_augur_score", axis=1)
    mean_scores = sorted_metrics.loc["mean_augur_score"]

    y_axes_range = range(1, len(sorted_metrics.columns) + 1)
    ax.hlines(
        y_axes_range,
        xmin=0,
        xmax=mean_scores,
    )

    ax.plot(
        mean_scores,
        y_axes_range,
        "o",
    )

    ax.set_xlabel("Mean Augur Score")
    ax.set_ylabel("Cell Type")
    ax.set_yticks(list(y_axes_range))
    ax.set_yticklabels(sorted_metrics.columns)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1188
|
+
|
1189
|
+
def plot_scatterplot(
    self,
    results1: dict[str, Any],
    results2: dict[str, Any],
    top_n: int | None = None,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Create scatterplot with two augur results.

    Only cell types present in both results are plotted.

    Args:
        results1: results after running `predict()`
        results2: results after running `predict()`
        top_n: optionally, the number of top prioritized cell types to label in the plot, ranked by
               the largest (score 1 - score 2) difference. If None, every cell type is labelled.
        return_fig: if truthy, the figure is returned.
        show: if truthy, `plt.show()` is called.
        save: if truthy, the figure is saved with this value passed on to `savefig`.

    Returns:
        The figure if `return_fig` is set, otherwise the Axes of the plot if neither `show` nor
        `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_scatterplot(v_results, h_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_scatterplot.png
    """
    # Restrict to cell types available in both results so that the `.loc` lookups below
    # cannot fail; the order of `results1` is kept.
    cell_types = results1["summary_metrics"].columns.intersection(results2["summary_metrics"].columns, sort=False)
    scores1 = results1["summary_metrics"].loc["mean_augur_score", cell_types]
    scores2 = results2["summary_metrics"].loc["mean_augur_score", cell_types]

    fig, ax = plt.subplots()
    ax.scatter(scores1, scores2)

    # adding optional labels
    top_n_cell_types = (scores1 - scores2).sort_values(ascending=False).index[:top_n]
    for txt in top_n_cell_types:
        ax.annotate(
            txt,
            (
                results1["summary_metrics"].loc["mean_augur_score", txt],
                results2["summary_metrics"].loc["mean_augur_score", txt],
            ),
        )

    # adding diagonal spanning the union of both axis ranges
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|