pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_augur.py
CHANGED
@@ -4,14 +4,17 @@ import random
|
|
4
4
|
from collections import defaultdict
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from math import floor, nan
|
7
|
-
from typing import Any, Literal
|
7
|
+
from typing import TYPE_CHECKING, Any, Literal
|
8
8
|
|
9
|
+
import anndata as ad
|
10
|
+
import matplotlib.pyplot as plt
|
9
11
|
import numpy as np
|
10
12
|
import pandas as pd
|
11
13
|
import scanpy as sc
|
12
14
|
import statsmodels.api as sm
|
13
15
|
from anndata import AnnData
|
14
16
|
from joblib import Parallel, delayed
|
17
|
+
from lamin_utils import logger
|
15
18
|
from rich import print
|
16
19
|
from rich.progress import track
|
17
20
|
from scipy import sparse, stats
|
@@ -34,6 +37,10 @@ from sklearn.preprocessing import LabelEncoder
|
|
34
37
|
from skmisc.loess import loess
|
35
38
|
from statsmodels.stats.multitest import fdrcorrection
|
36
39
|
|
40
|
+
if TYPE_CHECKING:
|
41
|
+
from matplotlib.axes import Axes
|
42
|
+
from matplotlib.figure import Figure
|
43
|
+
|
37
44
|
|
38
45
|
@dataclass
|
39
46
|
class Params:
|
@@ -121,7 +128,7 @@ class Augur:
|
|
121
128
|
_ = input[cell_type_col]
|
122
129
|
_ = input[label_col]
|
123
130
|
except KeyError:
|
124
|
-
|
131
|
+
logger.error("No column names matching cell_type_col and label_col.")
|
125
132
|
|
126
133
|
label = input[label_col] if meta is None else meta[label_col]
|
127
134
|
cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
|
@@ -134,9 +141,9 @@ class Augur:
|
|
134
141
|
if adata.obs["label"].dtype.name == "category":
|
135
142
|
# filter samples according to label
|
136
143
|
if condition_label is not None and treatment_label is not None:
|
137
|
-
|
138
|
-
adata =
|
139
|
-
adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
|
144
|
+
logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
|
145
|
+
adata = ad.concat(
|
146
|
+
[adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
|
140
147
|
)
|
141
148
|
label_encoder = LabelEncoder()
|
142
149
|
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
|
@@ -214,7 +221,9 @@ class Augur:
|
|
214
221
|
>>> loaded_data = ag_rfc.load(adata)
|
215
222
|
>>> ag_rfc.select_highly_variable(loaded_data)
|
216
223
|
>>> features = loaded_data.var_names
|
217
|
-
>>> subsample = ag_rfc.sample(
|
224
|
+
>>> subsample = ag_rfc.sample(
|
225
|
+
... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
|
226
|
+
... )
|
218
227
|
"""
|
219
228
|
# export subsampling.
|
220
229
|
random.seed(random_state)
|
@@ -230,7 +239,7 @@ class Augur:
|
|
230
239
|
random_state=random_state,
|
231
240
|
)
|
232
241
|
)
|
233
|
-
subsample =
|
242
|
+
subsample = ad.concat([*label_subsamples], index_unique=None)
|
234
243
|
else:
|
235
244
|
subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
|
236
245
|
|
@@ -409,8 +418,8 @@ class Augur:
|
|
409
418
|
"""
|
410
419
|
if multiclass:
|
411
420
|
return {
|
412
|
-
"augur_score": make_scorer(roc_auc_score, multi_class="ovo",
|
413
|
-
"auc": make_scorer(roc_auc_score, multi_class="ovo",
|
421
|
+
"augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
|
422
|
+
"auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
|
414
423
|
"accuracy": make_scorer(accuracy_score),
|
415
424
|
"precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
|
416
425
|
"f1": make_scorer(f1_score, average="macro"),
|
@@ -418,8 +427,8 @@ class Augur:
|
|
418
427
|
}
|
419
428
|
return (
|
420
429
|
{
|
421
|
-
"augur_score": make_scorer(roc_auc_score,
|
422
|
-
"auc": make_scorer(roc_auc_score,
|
430
|
+
"augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
|
431
|
+
"auc": make_scorer(roc_auc_score, response_method="predict_proba"),
|
423
432
|
"accuracy": make_scorer(accuracy_score),
|
424
433
|
"precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
|
425
434
|
"f1": make_scorer(f1_score, average="binary"),
|
@@ -488,7 +497,7 @@ class Augur:
|
|
488
497
|
# feature importances
|
489
498
|
feature_importances = defaultdict(list)
|
490
499
|
if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
|
491
|
-
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
|
500
|
+
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
|
492
501
|
feature_importances["genes"].extend(x.columns.tolist())
|
493
502
|
feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
|
494
503
|
feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
|
@@ -497,7 +506,7 @@ class Augur:
|
|
497
506
|
# standardized coefficients with Agresti method
|
498
507
|
# cf. https://think-lab.github.io/d/205/#3
|
499
508
|
if isinstance(self.estimator, LogisticRegression):
|
500
|
-
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
|
509
|
+
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
|
501
510
|
feature_importances["genes"].extend(x.columns.tolist())
|
502
511
|
feature_importances["feature_importances"].extend(
|
503
512
|
(self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
|
@@ -548,7 +557,7 @@ class Augur:
|
|
548
557
|
try:
|
549
558
|
sc.pp.highly_variable_genes(adata)
|
550
559
|
except ValueError:
|
551
|
-
|
560
|
+
logger.warn("Data not normalized. Normalizing now using scanpy log1p normalize.")
|
552
561
|
sc.pp.log1p(adata)
|
553
562
|
sc.pp.highly_variable_genes(adata)
|
554
563
|
|
@@ -600,7 +609,7 @@ class Augur:
|
|
600
609
|
var_quantile: The quantile below which features will be filtered, based on their residuals in a loess model.
|
601
610
|
filter_negative_residuals: if `True`, filter residuals at a fixed threshold of zero, instead of `var_quantile`
|
602
611
|
span: Smoothing factor, as a fraction of the number of points to take into account.
|
603
|
-
Should be in the range (0, 1].
|
612
|
+
Should be in the range (0, 1].
|
604
613
|
|
605
614
|
Return:
|
606
615
|
AnnData object with additional select_variance column in var.
|
@@ -692,13 +701,11 @@ class Augur:
|
|
692
701
|
feature_perc: proportion of genes that are randomly selected as features for input to the classifier in each
|
693
702
|
subsample using the random gene filter
|
694
703
|
var_quantile: The quantile below which features will be filtered, based on their residuals in a loess model.
|
695
|
-
Defaults to 0.5.
|
696
704
|
span: Smoothing factor, as a fraction of the number of points to take into account. Should be in the range (0, 1].
|
697
|
-
Defaults to 0.75.
|
698
705
|
filter_negative_residuals: if `True`, filter residuals at a fixed threshold of zero, instead of `var_quantile`
|
699
706
|
n_threads: number of threads to use for parallelization
|
700
707
|
select_variance_features: Whether to select genes based on the original Augur implementation (True)
|
701
|
-
or using scanpy's highly_variable_genes (False).
|
708
|
+
or using scanpy's highly_variable_genes (False).
|
702
709
|
key_added: Key to add results to in .uns
|
703
710
|
augur_mode: One of 'default', 'velocity' or 'permute'. Setting augur_mode = "velocity" disables feature selection,
|
704
711
|
assuming feature selection has been performed by the RNA velocity procedure to produce the input matrix,
|
@@ -723,6 +730,7 @@ class Augur:
|
|
723
730
|
>>> loaded_data = ag_rfc.load(adata)
|
724
731
|
>>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
|
725
732
|
"""
|
733
|
+
adata = adata.copy()
|
726
734
|
if augur_mode == "permute" and n_subsamples < 100:
|
727
735
|
n_subsamples = 500
|
728
736
|
if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
|
@@ -742,8 +750,8 @@ class Augur:
|
|
742
750
|
"full_results": defaultdict(list),
|
743
751
|
}
|
744
752
|
if select_variance_features:
|
745
|
-
|
746
|
-
|
753
|
+
logger.warning("Set smaller span value in the case of a `segmentation fault` error.")
|
754
|
+
logger.warning("Set larger span in case of svddc or other near singularities error.")
|
747
755
|
adata.obs["augur_score"] = nan
|
748
756
|
for cell_type in track(adata.obs["cell_type"].unique(), description="Processing data..."):
|
749
757
|
cell_type_subsample = adata[adata.obs["cell_type"] == cell_type].copy()
|
@@ -759,17 +767,18 @@ class Augur:
|
|
759
767
|
)
|
760
768
|
)
|
761
769
|
if len(cell_type_subsample) < min_cells:
|
762
|
-
|
763
|
-
f"
|
770
|
+
logger.warning(
|
771
|
+
f"Skipping {cell_type} cell type - {len(cell_type_subsample)} samples is less than min_cells {min_cells}."
|
764
772
|
)
|
765
773
|
elif (
|
766
774
|
cell_type_subsample.obs.groupby(
|
767
775
|
["cell_type", "label"],
|
776
|
+
observed=True,
|
768
777
|
).y_.count()
|
769
778
|
< subsample_size
|
770
779
|
).any():
|
771
|
-
|
772
|
-
f"
|
780
|
+
logger.warning(
|
781
|
+
f"Skipping {cell_type} cell type - the number of samples for at least one class type is less than "
|
773
782
|
f"subsample size {subsample_size}."
|
774
783
|
)
|
775
784
|
else:
|
@@ -804,14 +813,14 @@ class Augur:
|
|
804
813
|
* (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
|
805
814
|
)
|
806
815
|
|
807
|
-
for idx, cv in zip(range(n_subsamples), results[cell_type]):
|
816
|
+
for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
|
808
817
|
results["full_results"]["idx"].extend([idx] * folds)
|
809
818
|
results["full_results"]["augur_score"].extend(cv["test_augur_score"])
|
810
819
|
results["full_results"]["folds"].extend(range(folds))
|
811
820
|
results["full_results"]["cell_type"].extend([cell_type] * folds * n_subsamples)
|
812
821
|
# make sure one cell type worked
|
813
822
|
if len(results) <= 2:
|
814
|
-
|
823
|
+
logger.warning("No cells types had more than min_cells needed. Please adjust data or min_cells parameter.")
|
815
824
|
|
816
825
|
results["summary_metrics"] = pd.DataFrame(results["summary_metrics"])
|
817
826
|
results["feature_importances"] = pd.DataFrame(results["feature_importances"])
|
@@ -840,7 +849,7 @@ class Augur:
|
|
840
849
|
augur2: Augurpy results from condition 2, obtained from `predict()[1]`
|
841
850
|
permuted1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
|
842
851
|
permuted2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
|
843
|
-
n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation
|
852
|
+
n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation.
|
844
853
|
n_permutations: the total number of mean augur scores to calculate from a background distribution
|
845
854
|
|
846
855
|
Returns:
|
@@ -869,28 +878,31 @@ class Augur:
|
|
869
878
|
& set(permuted_results1["summary_metrics"].columns)
|
870
879
|
& set(permuted_results2["summary_metrics"].columns)
|
871
880
|
)
|
881
|
+
|
882
|
+
cell_types_list = list(cell_types)
|
883
|
+
|
872
884
|
# mean augur scores
|
873
885
|
augur_score1 = (
|
874
886
|
augur_results1["summary_metrics"]
|
875
|
-
.loc["mean_augur_score",
|
887
|
+
.loc["mean_augur_score", cell_types_list]
|
876
888
|
.reset_index()
|
877
889
|
.rename(columns={"index": "cell_type"})
|
878
890
|
)
|
879
891
|
augur_score2 = (
|
880
892
|
augur_results2["summary_metrics"]
|
881
|
-
.loc["mean_augur_score",
|
893
|
+
.loc["mean_augur_score", cell_types_list]
|
882
894
|
.reset_index()
|
883
895
|
.rename(columns={"index": "cell_type"})
|
884
896
|
)
|
885
897
|
|
886
898
|
# mean permuted scores over cross validation runs
|
887
899
|
permuted_cv_augur1 = (
|
888
|
-
permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(
|
900
|
+
permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
|
889
901
|
.groupby(["cell_type", "idx"], as_index=False)
|
890
902
|
.mean()
|
891
903
|
)
|
892
904
|
permuted_cv_augur2 = (
|
893
|
-
permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(
|
905
|
+
permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
|
894
906
|
.groupby(["cell_type", "idx"], as_index=False)
|
895
907
|
.mean()
|
896
908
|
)
|
@@ -901,7 +913,7 @@ class Augur:
|
|
901
913
|
# draw mean aucs for permute1 and permute2
|
902
914
|
for celltype in permuted_cv_augur1["cell_type"].unique():
|
903
915
|
df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
|
904
|
-
df2 = permuted_cv_augur2[
|
916
|
+
df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
|
905
917
|
for permutation_idx in range(n_permutations):
|
906
918
|
# subsample
|
907
919
|
sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
|
@@ -961,3 +973,288 @@ class Augur:
|
|
961
973
|
delta["padj"] = fdrcorrection(delta["pval"])[1]
|
962
974
|
|
963
975
|
return delta
|
976
|
+
|
977
|
+
def plot_dp_scatter(
    self,
    results: pd.DataFrame,
    top_n: int | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot scatterplot of differential prioritization.

    Each point is one row of ``results``. Its x/y position is ``mean_augur_score1`` / ``mean_augur_score2``
    and its colour is the ``z`` column. A dashed identity line is drawn for reference.

    Args:
        results: Results after running differential prioritization. Must contain the columns
            ``mean_augur_score1``, ``mean_augur_score2``, ``z``, ``pval`` and ``cell_type``.
        top_n: optionally, the number of top prioritized cell types (smallest ``pval``) to label in the plot.
            If `None`, every cell type is labelled.
        return_fig: if `True`, return the figure that contains the plot.
        ax: optionally, axes used to draw plot. A new figure and axes are created if `None`.
        show: if `True`, display the plot with ``plt.show()``.
        save: if set, path the figure containing the plot is saved to.

    Returns:
        The figure if ``return_fig`` is `True`; otherwise the axes of the plot if neither ``show`` nor ``save``
        is set; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.bhattacherjee()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
        >>> adata_15_permute, results_15_permute = ag_rfc.predict(
        ...     data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
        >>> adata_48_permute, results_48_permute = ag_rfc.predict(
        ...     data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> pvals = ag_rfc.predict_differential_prioritization(
        ...     augur_results1=results_15,
        ...     augur_results2=results_48,
        ...     permuted_results1=results_15_permute,
        ...     permuted_results2=results_48_permute,
        ... )
        >>> ag_rfc.plot_dp_scatter(pvals)

    Preview:
        .. image:: /_static/docstring_previews/augur_dp_scatter.png
    """
    x = results["mean_augur_score1"]
    y = results["mean_augur_score2"]

    if ax is None:
        _, ax = plt.subplots()
    scatter = ax.scatter(x, y, c=results.z, cmap="Greens")

    # label the `top_n` most significant cell types (slicing with None labels all of them)
    top_n_index = results.sort_values(by="pval").index[:top_n]
    for idx in top_n_index:
        ax.annotate(
            results.loc[idx, "cell_type"],
            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
        )

    # identity line covering the union of the x and y ranges, so it spans the whole visible data region
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    limits = (min(x_lo, y_lo), max(x_hi, y_hi))
    ax.plot(limits, limits, ls="--", c=".3")

    # format the axes we actually drew on; with a user-supplied `ax` it need not be pyplot's current axes
    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")
    legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
    ax.add_artist(legend1)

    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1050
|
+
|
1051
|
+
def plot_important_features(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    top_n: int = 10,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the n features with largest feature importances.

    Feature importances are averaged per gene before the ``top_n`` genes with the highest mean are plotted.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in ``.uns`` of the AnnData object under which the results are stored.
            Only used when ``data`` is an AnnData object.
        top_n: n number feature importance values to plot. Default is 10. If fewer genes are available,
            all of them are plotted.
        return_fig: if `True`, return the figure that contains the plot.
        ax: optionally, axes used to draw plot. A new figure and axes are created if `None`.
        show: if `True`, display the plot with ``plt.show()``.
        save: if set, path the figure containing the plot is saved to.

    Returns:
        The figure if ``return_fig`` is `True`; otherwise the axes of the plot if neither ``show`` nor ``save``
        is set; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_important_features(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_important_features.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data
    n_features = (
        results["feature_importances"]
        .groupby("genes", as_index=False)
        .feature_importances.mean()
        .sort_values(by="feature_importances")[-top_n:]
    )

    if ax is None:
        _, ax = plt.subplots()
    # one row per plotted gene: size by the data, since fewer than `top_n` genes may exist
    y_axes_range = range(1, len(n_features) + 1)
    ax.hlines(
        y_axes_range,
        xmin=0,
        xmax=n_features["feature_importances"],
    )

    ax.plot(n_features["feature_importances"], y_axes_range, "o")

    # format the axes we actually drew on; with a user-supplied `ax` it need not be pyplot's current axes
    ax.set_xlabel("Mean Feature Importance")
    ax.set_ylabel("Gene")
    ax.set_yticks(list(y_axes_range))
    ax.set_yticklabels(n_features["genes"].tolist())

    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1121
|
+
|
1122
|
+
def plot_lollipop(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the mean augur values.

    One row is drawn per cell type (column of ``summary_metrics``), ordered by ascending mean augur score.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in ``.uns`` of the AnnData object under which the results are stored.
            Only used when ``data`` is an AnnData object.
        return_fig: if `True`, return the figure that contains the plot.
        ax: optionally, axes used to draw plot. A new figure and axes are created if `None`.
        show: if `True`, display the plot with ``plt.show()``.
        save: if set, path the figure containing the plot is saved to.

    Returns:
        The figure if ``return_fig`` is `True`; otherwise the axes of the plot if neither ``show`` nor ``save``
        is set; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_lollipop(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_lollipop.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data
    if ax is None:
        _, ax = plt.subplots()

    # sort once; the resulting Series carries both the values and the matching cell-type labels
    mean_scores = results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"]
    y_axes_range = range(1, len(mean_scores) + 1)
    ax.hlines(
        y_axes_range,
        xmin=0,
        xmax=mean_scores,
    )

    ax.plot(
        mean_scores,
        y_axes_range,
        "o",
    )

    # format the axes we actually drew on; with a user-supplied `ax` it need not be pyplot's current axes
    ax.set_xlabel("Mean Augur Score")
    ax.set_ylabel("Cell Type")
    ax.set_yticks(list(y_axes_range))
    ax.set_yticklabels(mean_scores.index.tolist())

    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1187
|
+
|
1188
|
+
def plot_scatterplot(
    self,
    results1: dict[str, Any],
    results2: dict[str, Any],
    top_n: int | None = None,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Create scatterplot with two augur results.

    Each point is a cell type present in both results. Its x/y position is the mean augur score from
    ``results1`` / ``results2``. A dashed identity line is drawn for reference.

    Args:
        results1: results after running `predict()`
        results2: results after running `predict()`
        top_n: optionally, the number of cell types to label. These are the ones with the largest
            ``results1 - results2`` mean augur score difference. If `None`, every plotted cell type is labelled.
        return_fig: if `True`, return the figure that contains the plot.
        show: if `True`, display the plot with ``plt.show()``.
        save: if set, path the figure is saved to.

    Returns:
        The figure if ``return_fig`` is `True`; otherwise the axes of the plot if neither ``show`` nor ``save``
        is set; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_scatterplot(v_results, h_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_scatterplot.png
    """
    summary1 = results1["summary_metrics"]
    summary2 = results2["summary_metrics"]
    # only cell types scored in both result sets can be placed on the plot; indexing with cell types
    # missing from either summary would raise a KeyError
    cell_types = summary1.columns.intersection(summary2.columns)
    scores1 = summary1.loc["mean_augur_score", cell_types]
    scores2 = summary2.loc["mean_augur_score", cell_types]

    fig, ax = plt.subplots()
    ax.scatter(scores1, scores2)

    # label the `top_n` cell types with the largest results1 - results2 difference (all when top_n is None)
    top_n_cell_types = (scores1 - scores2).sort_values(ascending=False).index[:top_n]
    for txt in top_n_cell_types:
        ax.annotate(txt, (scores1[txt], scores2[txt]))

    # identity line covering the union of the x and y ranges, so it spans the whole visible data region
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    limits = (min(x_lo, y_lo), max(x_hi, y_hi))
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|