pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,16 @@
 from __future__ import annotations
 
+import uuid
 from typing import TYPE_CHECKING
 
 import numpy as np
+import pandas as pd
+import scanpy as sc
 import scipy
 
 if TYPE_CHECKING:
     from anndata import AnnData
+    from matplotlib.axes import Axes
 
 
 class GuideAssignment:
@@ -39,7 +43,7 @@ class GuideAssignment:
 
             >>> import pertpy as pt
             >>> mdata = pt.data.papalexi_2021()
-            >>> gdo = mdata.mod[
+            >>> gdo = mdata.mod["gdo"]
             >>> ga = pt.pp.GuideAssignment()
             >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
         """
@@ -71,7 +75,6 @@ class GuideAssignment:
 
         Args:
             adata: Annotated data matrix containing gRNA values
-            assignment_threshold: If a gRNA is available for at least `assignment_threshold`, it will be recognized as assigned.
             assignment_threshold: The count threshold that is required for an assignment to be viable.
             layer: Key to the layer containing raw count values of the gRNAs.
                 adata.X is used if layer is None. Expects count data.
@@ -83,8 +86,8 @@ class GuideAssignment:
             Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
 
             >>> import pertpy as pt
-            >>> mdata = pt.
-            >>> gdo = mdata.mod[
+            >>> mdata = pt.dt.papalexi_2021()
+            >>> gdo = mdata.mod["gdo"]
             >>> ga = pt.pp.GuideAssignment()
             >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
         """
@@ -103,3 +106,84 @@ class GuideAssignment:
         adata.obs[output_key] = assigned_grna
 
         return None
+
+    def plot_heatmap(
+        self,
+        adata: AnnData,
+        layer: str | None = None,
+        order_by: np.ndarray | str | None = None,
+        key_to_save_order: str = None,
+        **kwargs,
+    ) -> list[Axes]:
+        """Heatmap plotting of guide RNA expression matrix.
+
+        Assuming guides have sparse expression, this function reorders cells
+        and plots guide RNA expression so that a nice sparse representation is achieved.
+        The cell ordering can be stored and reused in future plots to obtain consistent
+        plots before and after analysis of the guide RNA expression.
+        Note: This function expects a log-normalized or binary data.
+
+        Args:
+            adata: Annotated data matrix containing gRNA values
+            layer: Key to the layer containing log normalized count values of the gRNAs.
+                adata.X is used if layer is None.
+            order_by: The order of cells in y axis. Defaults to None.
+                If None, cells will be reordered to have a nice sparse representation.
+                If a string is provided, adata.obs[order_by] will be used as the order.
+                If a numpy array is provided, the array will be used for ordering.
+            key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
+            kwargs: Are passed to sc.pl.heatmap.
+
+        Returns:
+            List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
+            Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
+
+        Examples:
+            Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
+            visualized using a heatmap.
+
+            >>> import pertpy as pt
+            >>> mdata = pt.dt.papalexi_2021()
+            >>> gdo = mdata.mod["gdo"]
+            >>> ga = pt.pp.GuideAssignment()
+            >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
+            >>> ga.plot_heatmap(gdo)
+        """
+        data = adata.X if layer is None else adata.layers[layer]
+
+        if order_by is None:
+            if scipy.sparse.issparse(data):
+                max_values = data.max(axis=1).A.squeeze()
+                data_argmax = data.argmax(axis=1).A.squeeze()
+                max_guide_index = np.where(max_values != data.min(axis=1).A.squeeze(), data_argmax, -1)
+            else:
+                max_guide_index = np.where(
+                    data.max(axis=1).squeeze() != data.min(axis=1).squeeze(), data.argmax(axis=1).squeeze(), -1
+                )
+            order = np.argsort(max_guide_index)
+        elif isinstance(order_by, str):
+            order = np.argsort(adata.obs[order_by])
+        else:
+            order = order_by
+
+        temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
+        adata.obs[temp_col_name] = pd.Categorical(["" for _ in range(adata.shape[0])])
+
+        if key_to_save_order is not None:
+            adata.obs[key_to_save_order] = pd.Categorical(order)
+
+        try:
+            axis_group = sc.pl.heatmap(
+                adata[order, :],
+                var_names=adata.var.index.tolist(),
+                groupby=temp_col_name,
+                cmap="viridis",
+                use_raw=False,
+                dendrogram=False,
+                layer=layer,
+                **kwargs,
+            )
+        finally:
+            del adata.obs[temp_col_name]
+
+        return axis_group
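The newly added `plot_heatmap` reorders cells by their dominant guide and then delegates to `sc.pl.heatmap`. A hedged usage sketch following the docstring example above; the `"grna_order"` key is an illustrative name for persisting the computed cell ordering so later plots can reuse it via `order_by`:

```python
import pertpy as pt

mdata = pt.dt.papalexi_2021()
gdo = mdata.mod["gdo"]

ga = pt.pp.GuideAssignment()
ga.assign_by_threshold(gdo, assignment_threshold=5)

# Plot and persist the computed cell order under .obs["grna_order"];
# the stored order can later be passed back through the order_by argument
# to obtain a consistent heatmap before and after further analysis.
ga.plot_heatmap(gdo, key_to_save_order="grna_order")
```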
pertpy/tools/__init__.py
CHANGED
@@ -1,24 +1,19 @@
-from rich import print
-
 from pertpy.tools._augur import Augur
 from pertpy.tools._cinemaot import Cinemaot
+from pertpy.tools._coda._sccoda import Sccoda
+from pertpy.tools._coda._tasccoda import Tasccoda
 from pertpy.tools._dialogue import Dialogue
 from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
 from pertpy.tools._distances._distance_tests import DistanceTest
 from pertpy.tools._distances._distances import Distance
-from pertpy.tools.
+from pertpy.tools._enrichment import Enrichment
 from pertpy.tools._milo import Milo
 from pertpy.tools._mixscape import Mixscape
 from pertpy.tools._perturbation_space._clustering import ClusteringSpace
-from pertpy.tools._perturbation_space.
+from pertpy.tools._perturbation_space._discriminator_classifiers import (
+    DiscriminatorClassifierSpace,
+    LRClassifierSpace,
+    MLPClassifierSpace,
+)
 from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
 from pertpy.tools._scgen import SCGEN
-
-try:
-    from pertpy.tools._coda._sccoda import Sccoda
-    from pertpy.tools._coda._tasccoda import Tasccoda
-except ImportError as e:
-    if "ete3" in str(e):
-        print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3")
-    else:
-        raise e
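After this change, `Sccoda` and `Tasccoda` are imported unconditionally (the ete3 try/except at import time is gone) and `Enrichment` plus the new classifier-space classes are exported from `pertpy.tools`. A hedged sketch of the resulting import surface, showing only names that appear in the diff (no constructor arguments, since those are not shown here):

```python
import pertpy as pt
from pertpy.tools import Enrichment, LRClassifierSpace, MLPClassifierSpace, Sccoda, Tasccoda

# The docstrings in this release consistently use the pt.tl alias for pertpy.tools,
# so the same classes should be reachable there as well.
ag = pt.tl.Augur("random_forest_classifier")
print(pt.tl.Enrichment is Enrichment)  # True if pt.tl aliases the pertpy.tools module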
pertpy/tools/_augur.py
CHANGED
@@ -4,8 +4,10 @@ import random
 from collections import defaultdict
 from dataclasses import dataclass
 from math import floor, nan
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
+import anndata as ad
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import scanpy as sc
@@ -34,6 +36,10 @@ from sklearn.preprocessing import LabelEncoder
 from skmisc.loess import loess
 from statsmodels.stats.multitest import fdrcorrection
 
+if TYPE_CHECKING:
+    from matplotlib.axes import Axes
+    from matplotlib.figure import Figure
+
 
 @dataclass
 class Params:
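The new `Axes` and `Figure` imports sit behind `typing.TYPE_CHECKING`, so they are only evaluated by static type checkers and never at runtime; the annotations then resolve as strings (via `from __future__ import annotations` or explicit quoting). A minimal sketch of the pattern, independent of pertpy:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; matplotlib is not imported when this module loads.
    from matplotlib.axes import Axes


def style_axes(ax: Axes) -> Axes:
    # The annotation above stays a string at runtime, so no matplotlib import is needed
    # unless a caller actually passes an Axes object in.
    ax.set_xlabel("x")
    return ax
```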
@@ -135,8 +141,8 @@ class Augur:
         # filter samples according to label
         if condition_label is not None and treatment_label is not None:
             print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
-            adata =
-                adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
+            adata = ad.concat(
+                [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
             )
             label_encoder = LabelEncoder()
             adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
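The rewritten filtering step builds the condition/treatment subset with `anndata.concat`, which takes a list of AnnData objects. A hedged, self-contained sketch of that call on toy data (the `"ctrl"`/`"stim"` labels are made up for illustration, not pertpy's):

```python
import anndata as ad
import numpy as np
import pandas as pd

# Toy AnnData with a "label" column, standing in for Augur's loaded data.
adata = ad.AnnData(
    X=np.random.rand(6, 3),
    obs=pd.DataFrame({"label": ["ctrl", "ctrl", "stim", "stim", "other", "other"]}),
)

# Keep only the two conditions of interest, as Augur.load() now does via ad.concat.
subset = ad.concat(
    [adata[adata.obs["label"] == "ctrl"], adata[adata.obs["label"] == "stim"]]
)
print(subset.obs["label"].value_counts())
```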
@@ -214,7 +220,9 @@ class Augur:
             >>> loaded_data = ag_rfc.load(adata)
             >>> ag_rfc.select_highly_variable(loaded_data)
             >>> features = loaded_data.var_names
-            >>> subsample = ag_rfc.sample(
+            >>> subsample = ag_rfc.sample(
+            ...     loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
+            ... )
         """
         # export subsampling.
         random.seed(random_state)
@@ -230,7 +238,7 @@ class Augur:
                         random_state=random_state,
                     )
                 )
-            subsample =
+            subsample = ad.concat([*label_subsamples], index_unique=None)
         else:
             subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
 
@@ -409,8 +417,8 @@ class Augur:
         """
         if multiclass:
             return {
-                "augur_score": make_scorer(roc_auc_score, multi_class="ovo",
-                "auc": make_scorer(roc_auc_score, multi_class="ovo",
+                "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
+                "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
                 "accuracy": make_scorer(accuracy_score),
                 "precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
                 "f1": make_scorer(f1_score, average="macro"),
@@ -418,8 +426,8 @@ class Augur:
             }
         return (
             {
-                "augur_score": make_scorer(roc_auc_score,
-                "auc": make_scorer(roc_auc_score,
+                "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
+                "auc": make_scorer(roc_auc_score, response_method="predict_proba"),
                 "accuracy": make_scorer(accuracy_score),
                 "precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
                 "f1": make_scorer(f1_score, average="binary"),
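The AUC scorers now pass `response_method="predict_proba"`, the keyword scikit-learn (version 1.4 and later, which this assumes) uses in `make_scorer` to request class probabilities for `roc_auc_score`. A hedged, self-contained sketch of how such a scorer dictionary behaves in cross-validation, on synthetic data rather than pertpy's:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, make_scorer, roc_auc_score
from sklearn.model_selection import cross_validate

X, y = make_classification(n_samples=200, random_state=0)

# AUC needs probabilities, hence response_method="predict_proba";
# accuracy works on plain predictions.
scorers = {
    "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
    "accuracy": make_scorer(accuracy_score),
}

results = cross_validate(RandomForestClassifier(random_state=0), X, y, scoring=scorers, cv=3)
print(results["test_augur_score"], results["test_accuracy"])
```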
@@ -488,7 +496,7 @@ class Augur:
         # feature importances
         feature_importances = defaultdict(list)
         if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
-            for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
+            for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
                 feature_importances["genes"].extend(x.columns.tolist())
                 feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
                 feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
@@ -497,7 +505,7 @@ class Augur:
         # standardized coefficients with Agresti method
         # cf. https://think-lab.github.io/d/205/#3
         if isinstance(self.estimator, LogisticRegression):
-            for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
+            for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
                 feature_importances["genes"].extend(x.columns.tolist())
                 feature_importances["feature_importances"].extend(
                     (self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
@@ -723,6 +731,7 @@ class Augur:
             >>> loaded_data = ag_rfc.load(adata)
             >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
         """
+        adata = adata.copy()
         if augur_mode == "permute" and n_subsamples < 100:
             n_subsamples = 500
         if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
@@ -765,6 +774,7 @@ class Augur:
                 elif (
                     cell_type_subsample.obs.groupby(
                         ["cell_type", "label"],
+                        observed=True,
                     ).y_.count()
                     < subsample_size
                 ).any():
@@ -804,7 +814,7 @@ class Augur:
                 * (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
             )
 
-            for idx, cv in zip(range(n_subsamples), results[cell_type]):
+            for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
                 results["full_results"]["idx"].extend([idx] * folds)
                 results["full_results"]["augur_score"].extend(cv["test_augur_score"])
                 results["full_results"]["folds"].extend(range(folds))
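The explicit `strict=False` keeps the pre-3.10 truncating behaviour of `zip` while satisfying linters that flag bare `zip` calls (for example Ruff's zip-without-explicit-strict check). A tiny sketch of the difference, assuming Python 3.10 or later:

```python
scores = [0.91, 0.87, 0.84]
folds = range(5)

# strict=False (the default behaviour) silently stops at the shorter iterable ...
print(list(zip(folds, scores, strict=False)))  # [(0, 0.91), (1, 0.87), (2, 0.84)]

# ... while strict=True raises if the lengths differ.
try:
    list(zip(folds, scores, strict=True))
except ValueError as err:
    print(err)
```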
@@ -869,28 +879,31 @@ class Augur:
             & set(permuted_results1["summary_metrics"].columns)
             & set(permuted_results2["summary_metrics"].columns)
         )
+
+        cell_types_list = list(cell_types)
+
         # mean augur scores
         augur_score1 = (
             augur_results1["summary_metrics"]
-            .loc["mean_augur_score",
+            .loc["mean_augur_score", cell_types_list]
             .reset_index()
             .rename(columns={"index": "cell_type"})
         )
         augur_score2 = (
             augur_results2["summary_metrics"]
-            .loc["mean_augur_score",
+            .loc["mean_augur_score", cell_types_list]
             .reset_index()
             .rename(columns={"index": "cell_type"})
         )
 
         # mean permuted scores over cross validation runs
         permuted_cv_augur1 = (
-            permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(
+            permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
             .groupby(["cell_type", "idx"], as_index=False)
             .mean()
         )
         permuted_cv_augur2 = (
-            permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(
+            permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
             .groupby(["cell_type", "idx"], as_index=False)
             .mean()
         )
@@ -901,7 +914,7 @@ class Augur:
         # draw mean aucs for permute1 and permute2
         for celltype in permuted_cv_augur1["cell_type"].unique():
             df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
-            df2 = permuted_cv_augur2[
+            df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
             for permutation_idx in range(n_permutations):
                 # subsample
                 sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
@@ -961,3 +974,288 @@ class Augur:
         delta["padj"] = fdrcorrection(delta["pval"])[1]
 
         return delta
+
+    def plot_dp_scatter(
+        self,
+        results: pd.DataFrame,
+        top_n: int = None,
+        return_fig: bool | None = None,
+        ax: Axes = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> Axes | Figure | None:
+        """Plot scatterplot of differential prioritization.
+
+        Args:
+            results: Results after running differential prioritization.
+            top_n: optionally, the number of top prioritized cell types to label in the plot
+            ax: optionally, axes used to draw plot
+
+        Returns:
+            Axes of the plot.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.bhattacherjee()
+            >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
+
+            >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
+            >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
+            >>> adata_15_permute, results_15_permute = ag_rfc.predict(data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)
+
+            >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
+            >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
+            >>> adata_48_permute, results_48_permute = ag_rfc.predict(data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)
+
+            >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
+                permuted_results1=results_15_permute, permuted_results2=results_48_permute)
+            >>> ag_rfc.plot_dp_scatter(pvals)
+
+        Preview:
+            .. image:: /_static/docstring_previews/augur_dp_scatter.png
+        """
+        x = results["mean_augur_score1"]
+        y = results["mean_augur_score2"]
+
+        if ax is None:
+            fig, ax = plt.subplots()
+        scatter = ax.scatter(x, y, c=results.z, cmap="Greens")
+
+        # adding optional labels
+        top_n_index = results.sort_values(by="pval").index[:top_n]
+        for idx in top_n_index:
+            ax.annotate(
+                results.loc[idx, "cell_type"],
+                (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
+            )
+
+        # add diagonal
+        limits = max(ax.get_xlim(), ax.get_ylim())
+        (_,) = ax.plot(limits, limits, ls="--", c=".3")
+
+        # formatting and details
+        plt.xlabel("Augur scores 1")
+        plt.ylabel("Augur scores 2")
+        legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
+        ax.add_artist(legend1)
+
+        if save:
+            plt.savefig(save, bbox_inches="tight")
+        if show:
+            plt.show()
+        if return_fig:
+            return plt.gcf()
+        if not (show or save):
+            return ax
+        return None
+
+    def plot_important_features(
+        self,
+        data: dict[str, Any],
+        key: str = "augurpy_results",
+        top_n: int = 10,
+        return_fig: bool | None = None,
+        ax: Axes = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> Axes | None:
+        """Plot a lollipop plot of the n features with largest feature importances.
+
+        Args:
+            results: results after running `predict()` as dictionary or the AnnData object.
+            key: Key in the AnnData object of the results
+            top_n: n number feature importance values to plot. Default is 10.
+            ax: optionally, axes used to draw plot
+            return_figure: if `True` returns figure of the plot, default is `False`
+
+        Returns:
+            Axes of the plot.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.sc_sim_augur()
+            >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
+            >>> loaded_data = ag_rfc.load(adata)
+            >>> v_adata, v_results = ag_rfc.predict(
+            ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
+            ... )
+            >>> ag_rfc.plot_important_features(v_results)
+
+        Preview:
+            .. image:: /_static/docstring_previews/augur_important_features.png
+        """
+        if isinstance(data, AnnData):
+            results = data.uns[key]
+        else:
+            results = data
+        n_features = (
+            results["feature_importances"]
+            .groupby("genes", as_index=False)
+            .feature_importances.mean()
+            .sort_values(by="feature_importances")[-top_n:]
+        )
+
+        if ax is None:
+            fig, ax = plt.subplots()
+        y_axes_range = range(1, top_n + 1)
+        ax.hlines(
+            y_axes_range,
+            xmin=0,
+            xmax=n_features["feature_importances"],
+        )
+
+        ax.plot(n_features["feature_importances"], y_axes_range, "o")
+
+        plt.xlabel("Mean Feature Importance")
+        plt.ylabel("Gene")
+        plt.yticks(y_axes_range, n_features["genes"])
+
+        if save:
+            plt.savefig(save, bbox_inches="tight")
+        if show:
+            plt.show()
+        if return_fig:
+            return plt.gcf()
+        if not (show or save):
+            return ax
+        return None
+
+    def plot_lollipop(
+        self,
+        data: dict[str, Any],
+        key: str = "augurpy_results",
+        return_fig: bool | None = None,
+        ax: Axes = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> Axes | Figure | None:
+        """Plot a lollipop plot of the mean augur values.
+
+        Args:
+            results: results after running `predict()` as dictionary or the AnnData object.
+            key: Key in the AnnData object of the results
+            ax: optionally, axes used to draw plot
+            return_figure: if `True` returns figure of the plot
+
+        Returns:
+            Axes of the plot.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.sc_sim_augur()
+            >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
+            >>> loaded_data = ag_rfc.load(adata)
+            >>> v_adata, v_results = ag_rfc.predict(
+            ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
+            ... )
+            >>> ag_rfc.plot_lollipop(v_results)
+
+        Preview:
+            .. image:: /_static/docstring_previews/augur_lollipop.png
+        """
+        if isinstance(data, AnnData):
+            results = data.uns[key]
+        else:
+            results = data
+        if ax is None:
+            fig, ax = plt.subplots()
+        y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
+        ax.hlines(
+            y_axes_range,
+            xmin=0,
+            xmax=results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
+        )
+
+        ax.plot(
+            results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
+            y_axes_range,
+            "o",
+        )
+
+        plt.xlabel("Mean Augur Score")
+        plt.ylabel("Cell Type")
+        plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
+
+        if save:
+            plt.savefig(save, bbox_inches="tight")
+        if show:
+            plt.show()
+        if return_fig:
+            return plt.gcf()
+        if not (show or save):
+            return ax
+        return None
+
+    def plot_scatterplot(
+        self,
+        results1: dict[str, Any],
+        results2: dict[str, Any],
+        top_n: int = None,
+        return_fig: bool | None = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> Axes | Figure | None:
+        """Create scatterplot with two augur results.
+
+        Args:
+            results1: results after running `predict()`
+            results2: results after running `predict()`
+            top_n: optionally, the number of top prioritized cell types to label in the plot
+            return_figure: if `True` returns figure of the plot
+
+        Returns:
+            Axes of the plot.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.sc_sim_augur()
+            >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
+            >>> loaded_data = ag_rfc.load(adata)
+            >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
+            >>> v_adata, v_results = ag_rfc.predict(
+            ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
+            ... )
+            >>> ag_rfc.plot_scatterplot(v_results, h_results)
+
+        Preview:
+            .. image:: /_static/docstring_previews/augur_scatterplot.png
+        """
+        cell_types = results1["summary_metrics"].columns
+
+        fig, ax = plt.subplots()
+        ax.scatter(
+            results1["summary_metrics"].loc["mean_augur_score", cell_types],
+            results2["summary_metrics"].loc["mean_augur_score", cell_types],
+        )
+
+        # adding optional labels
+        top_n_cell_types = (
+            (results1["summary_metrics"].loc["mean_augur_score"] - results2["summary_metrics"].loc["mean_augur_score"])
+            .sort_values(ascending=False)
+            .index[:top_n]
+        )
+        for txt in top_n_cell_types:
+            ax.annotate(
+                txt,
+                (
+                    results1["summary_metrics"].loc["mean_augur_score", txt],
+                    results2["summary_metrics"].loc["mean_augur_score", txt],
+                ),
+            )
+
+        # adding diagonal
+        limits = max(ax.get_xlim(), ax.get_ylim())
+        (diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
+
+        plt.xlabel("Augur scores 1")
+        plt.ylabel("Augur scores 2")
+
+        if save:
+            plt.savefig(save, bbox_inches="tight")
+        if show:
+            plt.show()
+        if return_fig:
+            return plt.gcf()
+        if not (show or save):
+            return ax
+        return None
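Taken together with the shrunken pertpy/plot/_augur.py in the file list above, the Augur plotting functionality now lives as methods on the `Augur` object rather than in the `pertpy.plot` module. A hedged end-to-end sketch assembled from the docstring examples above (dataset loader and parameters are taken from those docstrings):

```python
import pertpy as pt

adata = pt.dt.sc_sim_augur()
ag_rfc = pt.tl.Augur("random_forest_classifier")

loaded_data = ag_rfc.load(adata)
h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
v_adata, v_results = ag_rfc.predict(
    loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
)

# New in 0.7.0: plotting is called on the Augur object itself.
ag_rfc.plot_lollipop(v_results)
ag_rfc.plot_important_features(v_results)
ag_rfc.plot_scatterplot(v_results, h_results)
```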