pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_augur.py
CHANGED
@@ -4,14 +4,17 @@ import random
|
|
4
4
|
from collections import defaultdict
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from math import floor, nan
|
7
|
-
from typing import Any, Literal
|
7
|
+
from typing import TYPE_CHECKING, Any, Literal
|
8
8
|
|
9
|
+
import anndata as ad
|
10
|
+
import matplotlib.pyplot as plt
|
9
11
|
import numpy as np
|
10
12
|
import pandas as pd
|
11
13
|
import scanpy as sc
|
12
14
|
import statsmodels.api as sm
|
13
15
|
from anndata import AnnData
|
14
16
|
from joblib import Parallel, delayed
|
17
|
+
from lamin_utils import logger
|
15
18
|
from rich import print
|
16
19
|
from rich.progress import track
|
17
20
|
from scipy import sparse, stats
|
@@ -34,6 +37,10 @@ from sklearn.preprocessing import LabelEncoder
|
|
34
37
|
from skmisc.loess import loess
|
35
38
|
from statsmodels.stats.multitest import fdrcorrection
|
36
39
|
|
40
|
+
if TYPE_CHECKING:
|
41
|
+
from matplotlib.axes import Axes
|
42
|
+
from matplotlib.figure import Figure
|
43
|
+
|
37
44
|
|
38
45
|
@dataclass
|
39
46
|
class Params:
|
@@ -121,7 +128,7 @@ class Augur:
|
|
121
128
|
_ = input[cell_type_col]
|
122
129
|
_ = input[label_col]
|
123
130
|
except KeyError:
|
124
|
-
|
131
|
+
logger.error("No column names matching cell_type_col and label_col.")
|
125
132
|
|
126
133
|
label = input[label_col] if meta is None else meta[label_col]
|
127
134
|
cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
|
@@ -134,9 +141,9 @@ class Augur:
|
|
134
141
|
if adata.obs["label"].dtype.name == "category":
|
135
142
|
# filter samples according to label
|
136
143
|
if condition_label is not None and treatment_label is not None:
|
137
|
-
|
138
|
-
adata =
|
139
|
-
adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
|
144
|
+
logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
|
145
|
+
adata = ad.concat(
|
146
|
+
[adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
|
140
147
|
)
|
141
148
|
label_encoder = LabelEncoder()
|
142
149
|
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
|
@@ -214,7 +221,9 @@ class Augur:
|
|
214
221
|
>>> loaded_data = ag_rfc.load(adata)
|
215
222
|
>>> ag_rfc.select_highly_variable(loaded_data)
|
216
223
|
>>> features = loaded_data.var_names
|
217
|
-
>>> subsample = ag_rfc.sample(
|
224
|
+
>>> subsample = ag_rfc.sample(
|
225
|
+
... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
|
226
|
+
... )
|
218
227
|
"""
|
219
228
|
# export subsampling.
|
220
229
|
random.seed(random_state)
|
@@ -230,7 +239,7 @@ class Augur:
|
|
230
239
|
random_state=random_state,
|
231
240
|
)
|
232
241
|
)
|
233
|
-
subsample =
|
242
|
+
subsample = ad.concat([*label_subsamples], index_unique=None)
|
234
243
|
else:
|
235
244
|
subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
|
236
245
|
|
@@ -409,8 +418,8 @@ class Augur:
|
|
409
418
|
"""
|
410
419
|
if multiclass:
|
411
420
|
return {
|
412
|
-
"augur_score": make_scorer(roc_auc_score, multi_class="ovo",
|
413
|
-
"auc": make_scorer(roc_auc_score, multi_class="ovo",
|
421
|
+
"augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
|
422
|
+
"auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
|
414
423
|
"accuracy": make_scorer(accuracy_score),
|
415
424
|
"precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
|
416
425
|
"f1": make_scorer(f1_score, average="macro"),
|
@@ -418,8 +427,8 @@ class Augur:
|
|
418
427
|
}
|
419
428
|
return (
|
420
429
|
{
|
421
|
-
"augur_score": make_scorer(roc_auc_score,
|
422
|
-
"auc": make_scorer(roc_auc_score,
|
430
|
+
"augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
|
431
|
+
"auc": make_scorer(roc_auc_score, response_method="predict_proba"),
|
423
432
|
"accuracy": make_scorer(accuracy_score),
|
424
433
|
"precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
|
425
434
|
"f1": make_scorer(f1_score, average="binary"),
|
@@ -488,7 +497,7 @@ class Augur:
|
|
488
497
|
# feature importances
|
489
498
|
feature_importances = defaultdict(list)
|
490
499
|
if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
|
491
|
-
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
|
500
|
+
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
|
492
501
|
feature_importances["genes"].extend(x.columns.tolist())
|
493
502
|
feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
|
494
503
|
feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
|
@@ -497,7 +506,7 @@ class Augur:
|
|
497
506
|
# standardized coefficients with Agresti method
|
498
507
|
# cf. https://think-lab.github.io/d/205/#3
|
499
508
|
if isinstance(self.estimator, LogisticRegression):
|
500
|
-
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
|
509
|
+
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
|
501
510
|
feature_importances["genes"].extend(x.columns.tolist())
|
502
511
|
feature_importances["feature_importances"].extend(
|
503
512
|
(self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
|
@@ -548,7 +557,7 @@ class Augur:
|
|
548
557
|
try:
|
549
558
|
sc.pp.highly_variable_genes(adata)
|
550
559
|
except ValueError:
|
551
|
-
|
560
|
+
logger.warn("Data not normalized. Normalizing now using scanpy log1p normalize.")
|
552
561
|
sc.pp.log1p(adata)
|
553
562
|
sc.pp.highly_variable_genes(adata)
|
554
563
|
|
@@ -600,7 +609,7 @@ class Augur:
|
|
600
609
|
var_quantile: The quantile below which features will be filtered, based on their residuals in a loess model.
|
601
610
|
filter_negative_residuals: if `True`, filter residuals at a fixed threshold of zero, instead of `var_quantile`
|
602
611
|
span: Smoothing factor, as a fraction of the number of points to take into account.
|
603
|
-
Should be in the range (0, 1].
|
612
|
+
Should be in the range (0, 1].
|
604
613
|
|
605
614
|
Return:
|
606
615
|
AnnData object with additional select_variance column in var.
|
@@ -692,13 +701,11 @@ class Augur:
|
|
692
701
|
feature_perc: proportion of genes that are randomly selected as features for input to the classifier in each
|
693
702
|
subsample using the random gene filter
|
694
703
|
var_quantile: The quantile below which features will be filtered, based on their residuals in a loess model.
|
695
|
-
Defaults to 0.5.
|
696
704
|
span: Smoothing factor, as a fraction of the number of points to take into account. Should be in the range (0, 1].
|
697
|
-
Defaults to 0.75.
|
698
705
|
filter_negative_residuals: if `True`, filter residuals at a fixed threshold of zero, instead of `var_quantile`
|
699
706
|
n_threads: number of threads to use for parallelization
|
700
707
|
select_variance_features: Whether to select genes based on the original Augur implementation (True)
|
701
|
-
or using scanpy's highly_variable_genes (False).
|
708
|
+
or using scanpy's highly_variable_genes (False).
|
702
709
|
key_added: Key to add results to in .uns
|
703
710
|
augur_mode: One of 'default', 'velocity' or 'permute'. Setting augur_mode = "velocity" disables feature selection,
|
704
711
|
assuming feature selection has been performed by the RNA velocity procedure to produce the input matrix,
|
@@ -723,6 +730,7 @@ class Augur:
|
|
723
730
|
>>> loaded_data = ag_rfc.load(adata)
|
724
731
|
>>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
|
725
732
|
"""
|
733
|
+
adata = adata.copy()
|
726
734
|
if augur_mode == "permute" and n_subsamples < 100:
|
727
735
|
n_subsamples = 500
|
728
736
|
if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
|
@@ -742,8 +750,8 @@ class Augur:
|
|
742
750
|
"full_results": defaultdict(list),
|
743
751
|
}
|
744
752
|
if select_variance_features:
|
745
|
-
|
746
|
-
|
753
|
+
logger.warning("Set smaller span value in the case of a `segmentation fault` error.")
|
754
|
+
logger.warning("Set larger span in case of svddc or other near singularities error.")
|
747
755
|
adata.obs["augur_score"] = nan
|
748
756
|
for cell_type in track(adata.obs["cell_type"].unique(), description="Processing data..."):
|
749
757
|
cell_type_subsample = adata[adata.obs["cell_type"] == cell_type].copy()
|
@@ -759,17 +767,18 @@ class Augur:
|
|
759
767
|
)
|
760
768
|
)
|
761
769
|
if len(cell_type_subsample) < min_cells:
|
762
|
-
|
763
|
-
f"
|
770
|
+
logger.warning(
|
771
|
+
f"Skipping {cell_type} cell type - {len(cell_type_subsample)} samples is less than min_cells {min_cells}."
|
764
772
|
)
|
765
773
|
elif (
|
766
774
|
cell_type_subsample.obs.groupby(
|
767
775
|
["cell_type", "label"],
|
776
|
+
observed=True,
|
768
777
|
).y_.count()
|
769
778
|
< subsample_size
|
770
779
|
).any():
|
771
|
-
|
772
|
-
f"
|
780
|
+
logger.warning(
|
781
|
+
f"Skipping {cell_type} cell type - the number of samples for at least one class type is less than "
|
773
782
|
f"subsample size {subsample_size}."
|
774
783
|
)
|
775
784
|
else:
|
@@ -804,14 +813,14 @@ class Augur:
|
|
804
813
|
* (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
|
805
814
|
)
|
806
815
|
|
807
|
-
for idx, cv in zip(range(n_subsamples), results[cell_type]):
|
816
|
+
for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
|
808
817
|
results["full_results"]["idx"].extend([idx] * folds)
|
809
818
|
results["full_results"]["augur_score"].extend(cv["test_augur_score"])
|
810
819
|
results["full_results"]["folds"].extend(range(folds))
|
811
820
|
results["full_results"]["cell_type"].extend([cell_type] * folds * n_subsamples)
|
812
821
|
# make sure one cell type worked
|
813
822
|
if len(results) <= 2:
|
814
|
-
|
823
|
+
logger.warning("No cells types had more than min_cells needed. Please adjust data or min_cells parameter.")
|
815
824
|
|
816
825
|
results["summary_metrics"] = pd.DataFrame(results["summary_metrics"])
|
817
826
|
results["feature_importances"] = pd.DataFrame(results["feature_importances"])
|
@@ -840,7 +849,7 @@ class Augur:
|
|
840
849
|
augur2: Augurpy results from condition 2, obtained from `predict()[1]`
|
841
850
|
permuted1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
|
842
851
|
permuted2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
|
843
|
-
n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation
|
852
|
+
n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation.
|
844
853
|
n_permutations: the total number of mean augur scores to calculate from a background distribution
|
845
854
|
|
846
855
|
Returns:
|
@@ -869,28 +878,31 @@ class Augur:
|
|
869
878
|
& set(permuted_results1["summary_metrics"].columns)
|
870
879
|
& set(permuted_results2["summary_metrics"].columns)
|
871
880
|
)
|
881
|
+
|
882
|
+
cell_types_list = list(cell_types)
|
883
|
+
|
872
884
|
# mean augur scores
|
873
885
|
augur_score1 = (
|
874
886
|
augur_results1["summary_metrics"]
|
875
|
-
.loc["mean_augur_score",
|
887
|
+
.loc["mean_augur_score", cell_types_list]
|
876
888
|
.reset_index()
|
877
889
|
.rename(columns={"index": "cell_type"})
|
878
890
|
)
|
879
891
|
augur_score2 = (
|
880
892
|
augur_results2["summary_metrics"]
|
881
|
-
.loc["mean_augur_score",
|
893
|
+
.loc["mean_augur_score", cell_types_list]
|
882
894
|
.reset_index()
|
883
895
|
.rename(columns={"index": "cell_type"})
|
884
896
|
)
|
885
897
|
|
886
898
|
# mean permuted scores over cross validation runs
|
887
899
|
permuted_cv_augur1 = (
|
888
|
-
permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(
|
900
|
+
permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
|
889
901
|
.groupby(["cell_type", "idx"], as_index=False)
|
890
902
|
.mean()
|
891
903
|
)
|
892
904
|
permuted_cv_augur2 = (
|
893
|
-
permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(
|
905
|
+
permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
|
894
906
|
.groupby(["cell_type", "idx"], as_index=False)
|
895
907
|
.mean()
|
896
908
|
)
|
@@ -901,7 +913,7 @@ class Augur:
|
|
901
913
|
# draw mean aucs for permute1 and permute2
|
902
914
|
for celltype in permuted_cv_augur1["cell_type"].unique():
|
903
915
|
df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
|
904
|
-
df2 = permuted_cv_augur2[
|
916
|
+
df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
|
905
917
|
for permutation_idx in range(n_permutations):
|
906
918
|
# subsample
|
907
919
|
sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
|
@@ -961,3 +973,288 @@ class Augur:
|
|
961
973
|
delta["padj"] = fdrcorrection(delta["pval"])[1]
|
962
974
|
|
963
975
|
return delta
|
976
|
+
|
977
|
+
def plot_dp_scatter(
    self,
    results: pd.DataFrame,
    top_n: int | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot scatterplot of differential prioritization.

    Args:
        results: Results after running differential prioritization. The columns
            `mean_augur_score1`, `mean_augur_score2`, `z`, `pval` and `cell_type` are read.
        top_n: Number of cell types with the smallest `pval` to label in the plot.
            `None` labels every cell type.
        return_fig: if `True`, returns the figure that contains the plot.
        ax: optionally, axes used to draw plot. A new figure is created when omitted.
        show: if `True`, calls `plt.show()` after drawing.
        save: if truthy, passed to `savefig` as the output target.

    Returns:
        The figure if `return_fig` is set. Otherwise the axes of the plot,
        or `None` when `show` or `save` is set.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.bhattacherjee()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
        >>> adata_15_permute, results_15_permute = ag_rfc.predict(
        ...     data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
        >>> adata_48_permute, results_48_permute = ag_rfc.predict(
        ...     data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> pvals = ag_rfc.predict_differential_prioritization(
        ...     augur_results1=results_15,
        ...     augur_results2=results_48,
        ...     permuted_results1=results_15_permute,
        ...     permuted_results2=results_48_permute,
        ... )
        >>> ag_rfc.plot_dp_scatter(pvals)

    Preview:
        .. image:: /_static/docstring_previews/augur_dp_scatter.png
    """
    x = results["mean_augur_score1"]
    y = results["mean_augur_score2"]

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`; pyplot's "current" figure may be a different one
    # when the caller supplies their own axes.
    fig = ax.get_figure()
    scatter = ax.scatter(x, y, c=results["z"], cmap="Greens")

    # adding optional labels (slicing with `None` keeps every row)
    top_n_index = results.sort_values(by="pval").index[:top_n]
    for idx in top_n_index:
        ax.annotate(
            results.loc[idx, "cell_type"],
            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
        )

    # add diagonal spanning the union of both axis ranges; comparing the limit tuples with
    # `max` would be lexicographic and could pick a range that does not cover both axes
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    # formatting and details
    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")
    legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
    ax.add_artist(legend1)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1050
|
+
|
1051
|
+
def plot_important_features(
    self,
    data: AnnData | dict[str, Any],
    key: str = "augurpy_results",
    top_n: int = 10,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the n features with largest feature importances.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `.uns` of the AnnData object under which the results are stored.
            Only used when `data` is an AnnData object.
        top_n: n number feature importance values to plot. Default is 10.
            If fewer genes are available, all of them are plotted.
        return_fig: if `True`, returns the figure that contains the plot.
        ax: optionally, axes used to draw plot. A new figure is created when omitted.
        show: if `True`, calls `plt.show()` after drawing.
        save: if truthy, passed to `savefig` as the output target.

    Returns:
        The figure if `return_fig` is set. Otherwise the axes of the plot,
        or `None` when `show` or `save` is set.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_important_features(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_important_features.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    # Mean importance per gene, ascending, keeping the `top_n` largest.
    # `tail` (unlike `[-top_n:]`) yields an empty frame for `top_n=0` instead of every row.
    top_features = (
        results["feature_importances"]
        .groupby("genes", as_index=False)
        .feature_importances.mean()
        .sort_values(by="feature_importances")
        .tail(top_n)
    )

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`; pyplot's "current" figure may be a different one
    # when the caller supplies their own axes.
    fig = ax.get_figure()

    # One row per plotted gene; sized from the data so that fewer than `top_n` genes
    # does not produce mismatched array lengths.
    y_positions = list(range(1, len(top_features) + 1))
    ax.hlines(
        y_positions,
        xmin=0,
        xmax=top_features["feature_importances"],
    )
    ax.plot(top_features["feature_importances"], y_positions, "o")

    ax.set_xlabel("Mean Feature Importance")
    ax.set_ylabel("Gene")
    ax.set_yticks(y_positions)
    ax.set_yticklabels(top_features["genes"])

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1121
|
+
|
1122
|
+
def plot_lollipop(
    self,
    data: AnnData | dict[str, Any],
    key: str = "augurpy_results",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the mean augur values.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `.uns` of the AnnData object under which the results are stored.
            Only used when `data` is an AnnData object.
        return_fig: if `True`, returns the figure that contains the plot.
        ax: optionally, axes used to draw plot. A new figure is created when omitted.
        show: if `True`, calls `plt.show()` after drawing.
        save: if truthy, passed to `savefig` as the output target.

    Returns:
        The figure if `return_fig` is set. Otherwise the axes of the plot,
        or `None` when `show` or `save` is set.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_lollipop(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_lollipop.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`; pyplot's "current" figure may be a different one
    # when the caller supplies their own axes.
    fig = ax.get_figure()

    # Order the cell types (columns) by their mean augur score once and reuse the result
    # for the stems, the markers and the tick labels.
    sorted_metrics = results["summary_metrics"].sort_values("mean_augur_score", axis=1)
    mean_scores = sorted_metrics.loc["mean_augur_score"]

    y_positions = list(range(1, len(sorted_metrics.columns) + 1))
    ax.hlines(
        y_positions,
        xmin=0,
        xmax=mean_scores,
    )
    ax.plot(mean_scores, y_positions, "o")

    ax.set_xlabel("Mean Augur Score")
    ax.set_ylabel("Cell Type")
    ax.set_yticks(y_positions)
    ax.set_yticklabels(sorted_metrics.columns)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1187
|
+
|
1188
|
+
def plot_scatterplot(
    self,
    results1: dict[str, Any],
    results2: dict[str, Any],
    top_n: int | None = None,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Create scatterplot with two augur results.

    Only cell types that are present in both results are plotted.

    Args:
        results1: results after running `predict()`
        results2: results after running `predict()`
        top_n: optionally, the number of cell types to label in the plot, taken in descending
            order of `mean_augur_score` of `results1` minus that of `results2`.
            `None` labels every cell type.
        return_fig: if `True`, returns the figure that contains the plot.
        show: if `True`, calls `plt.show()` after drawing.
        save: if truthy, passed to `savefig` as the output target.

    Returns:
        The figure if `return_fig` is set. Otherwise the axes of the plot,
        or `None` when `show` or `save` is set.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_scatterplot(v_results, h_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_scatterplot.png
    """
    metrics1 = results1["summary_metrics"]
    metrics2 = results2["summary_metrics"]

    # Restrict to the cell types both runs report; a cell type missing from one of the
    # results would otherwise raise a KeyError in the lookups below.
    cell_types = metrics1.columns.intersection(metrics2.columns)
    scores1 = metrics1.loc["mean_augur_score", cell_types]
    scores2 = metrics2.loc["mean_augur_score", cell_types]

    fig, ax = plt.subplots()
    ax.scatter(scores1, scores2)

    # adding optional labels (slicing with `None` keeps every cell type)
    top_n_cell_types = (scores1 - scores2).sort_values(ascending=False).index[:top_n]
    for cell_type in top_n_cell_types:
        ax.annotate(
            cell_type,
            (
                metrics1.loc["mean_augur_score", cell_type],
                metrics2.loc["mean_augur_score", cell_type],
            ),
        )

    # adding diagonal spanning the union of both axis ranges; comparing the limit tuples with
    # `max` would be lexicographic and could pick a range that does not cover both axes
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|