pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import itertools
|
4
4
|
from collections import defaultdict
|
5
|
-
from typing import Any, Literal
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal
|
6
6
|
|
7
7
|
import anndata as ad
|
8
|
+
import matplotlib.pyplot as plt
|
8
9
|
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
import scanpy as sc
|
11
|
-
import
|
12
|
+
import seaborn as sns
|
12
13
|
import statsmodels.formula.api as smf
|
13
14
|
import statsmodels.stats.multitest as ssm
|
14
15
|
from anndata import AnnData
|
@@ -19,10 +20,15 @@ from rich.live import Live
|
|
19
20
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
20
21
|
from scipy import stats
|
21
22
|
from scipy.optimize import nnls
|
23
|
+
from seaborn import PairGrid
|
22
24
|
from sklearn.linear_model import LinearRegression
|
23
25
|
from sparsecca import lp_pmd, multicca_permute, multicca_pmd
|
24
26
|
from statsmodels.sandbox.stats.multicomp import multipletests
|
25
27
|
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from matplotlib.axes import Axes
|
30
|
+
from matplotlib.figure import Figure
|
31
|
+
|
26
32
|
|
27
33
|
class Dialogue:
|
28
34
|
"""Python implementation of DIALOGUE"""
|
@@ -53,8 +59,6 @@ class Dialogue:
|
|
53
59
|
|
54
60
|
Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
|
55
61
|
|
56
|
-
# TODO: Replace with decoupler's implementation
|
57
|
-
|
58
62
|
Args:
|
59
63
|
groupby: The key to groupby for pseudobulks
|
60
64
|
strategy: The pseudobulking strategy. One of "median" or "mean"
|
@@ -62,14 +66,15 @@ class Dialogue:
|
|
62
66
|
Returns:
|
63
67
|
A Pandas DataFrame of pseudobulk counts
|
64
68
|
"""
|
69
|
+
# TODO: Replace with decoupler's implementation
|
65
70
|
pseudobulk = {"Genes": adata.var_names.values}
|
66
71
|
|
67
72
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
68
73
|
temp = adata.obs.loc[:, groupby] == category
|
69
74
|
if strategy == "median":
|
70
|
-
pseudobulk[category] = adata[temp].X.median(axis=0)
|
75
|
+
pseudobulk[category] = adata[temp].X.median(axis=0)
|
71
76
|
elif strategy == "mean":
|
72
|
-
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
77
|
+
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
73
78
|
|
74
79
|
pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
|
75
80
|
|
@@ -101,8 +106,6 @@ class Dialogue:
|
|
101
106
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
102
107
|
"""Row-wise mean center and scale by the standard deviation.
|
103
108
|
|
104
|
-
TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
105
|
-
|
106
109
|
Args:
|
107
110
|
pseudobulks: The pseudobulk PCA components.
|
108
111
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
@@ -110,9 +113,9 @@ class Dialogue:
|
|
110
113
|
Returns:
|
111
114
|
The scaled count matrix.
|
112
115
|
"""
|
116
|
+
# TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
113
117
|
# DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
|
114
118
|
# However, performing this scaling _does_ increase overall correlation of the end result
|
115
|
-
# WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
|
116
119
|
if normalize:
|
117
120
|
return pseudobulks.to_numpy()
|
118
121
|
else:
|
@@ -313,13 +316,13 @@ class Dialogue:
|
|
313
316
|
def _apply_HLM_per_MCP_for_one_pair(
|
314
317
|
self,
|
315
318
|
mcp_name: str,
|
316
|
-
scores_df:
|
319
|
+
scores_df: pd.DataFrame,
|
317
320
|
ct_data: AnnData,
|
318
321
|
tme: pd.DataFrame,
|
319
322
|
sig: dict,
|
320
323
|
n_counts: str,
|
321
324
|
formula: str,
|
322
|
-
confounder: str,
|
325
|
+
confounder: str | None,
|
323
326
|
) -> tuple[pd.DataFrame, dict[str, Any]]:
|
324
327
|
"""Applies hierarchical modeling for a single MCP.
|
325
328
|
|
@@ -340,7 +343,7 @@ class Dialogue:
|
|
340
343
|
"""
|
341
344
|
HLM_result = self._mixed_effects(
|
342
345
|
scores=scores_df[[mcp_name]],
|
343
|
-
x_labels=ct_data.obs[[n_counts, confounder]],
|
346
|
+
x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
|
344
347
|
tme=tme,
|
345
348
|
genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
|
346
349
|
formula=formula,
|
@@ -367,7 +370,7 @@ class Dialogue:
|
|
367
370
|
return np.array(resid)
|
368
371
|
|
369
372
|
def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
|
370
|
-
"""Solves non-negative least
|
373
|
+
"""Solves non-negative least-squares separately for different feature categories.
|
371
374
|
|
372
375
|
Mimics DLG.iterative.nnls.
|
373
376
|
Variables are notated according to:
|
@@ -398,7 +401,7 @@ class Dialogue:
|
|
398
401
|
|
399
402
|
x_final = np.zeros(A_orig.shape[0])
|
400
403
|
Ax = np.zeros(A_orig.shape[1])
|
401
|
-
for _, mask in zip(sig_ranks, masks):
|
404
|
+
for _, mask in zip(sig_ranks, masks, strict=False):
|
402
405
|
A = A_orig[mask].T
|
403
406
|
coef_nnls, _ = nnls(A, y, maxiter=n_iter)
|
404
407
|
y = y - A @ coef_nnls # residuals
|
@@ -516,8 +519,8 @@ class Dialogue:
|
|
516
519
|
# TODO: probably format the up and down within get_top_elements
|
517
520
|
cca_sig: dict[str, Any] = defaultdict(dict)
|
518
521
|
for i in range(0, int(len(cca_sig_unformatted) / 2)):
|
519
|
-
cca_sig[f"MCP{i
|
520
|
-
cca_sig[f"MCP{i
|
522
|
+
cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
|
523
|
+
cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
|
521
524
|
|
522
525
|
cca_sig = dict(cca_sig)
|
523
526
|
cca_sig_results[ct] = cca_sig
|
@@ -555,7 +558,7 @@ class Dialogue:
|
|
555
558
|
|
556
559
|
return cca_sig_results, new_mcp_scores
|
557
560
|
|
558
|
-
def
|
561
|
+
def _load(
|
559
562
|
self,
|
560
563
|
adata: AnnData,
|
561
564
|
ct_order: list[str],
|
@@ -574,16 +577,6 @@ class Dialogue:
|
|
574
577
|
|
575
578
|
Returns:
|
576
579
|
A celltype_label:array dictionary.
|
577
|
-
|
578
|
-
Examples:
|
579
|
-
>>> import pertpy as pt
|
580
|
-
>>> import scanpy as sc
|
581
|
-
>>> adata = pt.dt.dialogue_example()
|
582
|
-
>>> sc.pp.pca(adata)
|
583
|
-
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
584
|
-
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
585
|
-
>>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
|
586
|
-
>>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
|
587
580
|
"""
|
588
581
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
589
582
|
fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
|
@@ -631,19 +624,19 @@ class Dialogue:
|
|
631
624
|
>>> import scanpy as sc
|
632
625
|
>>> adata = pt.dt.dialogue_example()
|
633
626
|
>>> sc.pp.pca(adata)
|
634
|
-
>>> dl = pt.tl.Dialogue(
|
635
|
-
|
627
|
+
>>> dl = pt.tl.Dialogue(
|
628
|
+
... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
|
629
|
+
... )
|
636
630
|
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
637
631
|
"""
|
638
|
-
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
639
|
-
# it is important here that to obtain the same result as in R, we pass the matrices in
|
640
|
-
# in the same order.
|
632
|
+
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
633
|
+
# As such, it is important here that to obtain the same result as in R, we pass the matrices in the same order.
|
641
634
|
if ct_order is not None:
|
642
635
|
cell_types = ct_order
|
643
636
|
else:
|
644
637
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
645
638
|
|
646
|
-
mcca_in, ct_subs = self.
|
639
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
|
647
640
|
|
648
641
|
n_samples = mcca_in[0].shape[1]
|
649
642
|
if penalties is None:
|
@@ -685,7 +678,7 @@ class Dialogue:
|
|
685
678
|
ct_subs: dict,
|
686
679
|
mcp_scores: dict,
|
687
680
|
ws_dict: dict,
|
688
|
-
confounder: str,
|
681
|
+
confounder: str | None,
|
689
682
|
formula: str = None,
|
690
683
|
):
|
691
684
|
"""Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
|
@@ -700,7 +693,6 @@ class Dialogue:
|
|
700
693
|
A Pandas DataFrame containing:
|
701
694
|
- for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
|
702
695
|
- merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
|
703
|
-
TODO: Describe both returns
|
704
696
|
|
705
697
|
Examples:
|
706
698
|
>>> import pertpy as pt
|
@@ -713,7 +705,9 @@ class Dialogue:
|
|
713
705
|
>>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
|
714
706
|
confounder="gender")
|
715
707
|
"""
|
716
|
-
#
|
708
|
+
# TODO the returns of the function better
|
709
|
+
|
710
|
+
# all possible pairs of cell types without pairing same cell type
|
717
711
|
cell_types = list(ct_subs.keys())
|
718
712
|
pairs = list(itertools.combinations(cell_types, 2))
|
719
713
|
|
@@ -721,9 +715,9 @@ class Dialogue:
|
|
721
715
|
formula = f"y ~ x + {self.n_counts_key}"
|
722
716
|
|
723
717
|
# Hierarchical modeling expects DataFrames
|
724
|
-
mcp_cell_types = {f"MCP{i
|
718
|
+
mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
|
725
719
|
mcp_scores_df = {
|
726
|
-
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
|
720
|
+
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
|
727
721
|
for ct, v in mcp_scores.items()
|
728
722
|
}
|
729
723
|
|
@@ -805,7 +799,7 @@ class Dialogue:
|
|
805
799
|
for mcp in mcps:
|
806
800
|
mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
|
807
801
|
|
808
|
-
# TODO Check
|
802
|
+
# TODO Check whether the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
|
809
803
|
result = {}
|
810
804
|
result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
|
811
805
|
mcp_name=mcp,
|
@@ -875,22 +869,19 @@ class Dialogue:
|
|
875
869
|
sample_label = self.sample_id
|
876
870
|
n_mcps = self.n_mcps
|
877
871
|
|
878
|
-
# create conditions_compare if not supplied
|
879
872
|
if conditions_compare is None:
|
880
|
-
conditions_compare = list(adata.obs[
|
873
|
+
conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
|
881
874
|
if len(conditions_compare) != 2:
|
882
875
|
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
|
883
876
|
|
884
|
-
# create data frames to store results
|
885
877
|
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
886
878
|
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
887
879
|
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
888
880
|
|
889
881
|
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
|
890
882
|
for celltype in adata.obs[celltype_label].unique():
|
891
|
-
# subset data to cell type
|
892
883
|
df = adata.obs[adata.obs[celltype_label] == celltype]
|
893
|
-
|
884
|
+
|
894
885
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
895
886
|
mns = df.groupby(sample_label)[mcpnum].mean()
|
896
887
|
mns = pd.concat([mns, response], axis=1)
|
@@ -900,11 +891,10 @@ class Dialogue:
|
|
900
891
|
)
|
901
892
|
pvals.loc[celltype, mcpnum] = res[1]
|
902
893
|
tstats.loc[celltype, mcpnum] = res[0]
|
903
|
-
# return(res)
|
904
894
|
|
905
|
-
# benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
|
906
895
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
907
896
|
pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
|
897
|
+
|
908
898
|
return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
|
909
899
|
|
910
900
|
def get_mlm_mcp_genes(
|
@@ -921,7 +911,7 @@ class Dialogue:
|
|
921
911
|
celltype: Cell type of interest.
|
922
912
|
results: dl.MultilevelModeling result object.
|
923
913
|
MCP: MCP key of the result object.
|
924
|
-
|
914
|
+
threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
|
925
915
|
Defaults to 0.70.
|
926
916
|
focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
|
927
917
|
Defaults to None.
|
@@ -945,7 +935,6 @@ class Dialogue:
|
|
945
935
|
# REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
|
946
936
|
if MCP.startswith("mcp_"):
|
947
937
|
MCP = MCP.replace("mcp_", "MCP")
|
948
|
-
# convert from MCPx to MCPx+1
|
949
938
|
MCP = "MCP" + str(int(MCP[3:]) - 1)
|
950
939
|
|
951
940
|
# Extract all comparison keys from the results object
|
@@ -1014,17 +1003,16 @@ class Dialogue:
|
|
1014
1003
|
objects containing the results of gene ranking analysis.
|
1015
1004
|
|
1016
1005
|
Examples:
|
1017
|
-
ct_subs = {
|
1018
|
-
"subpop1": anndata_obj1,
|
1019
|
-
"subpop2": anndata_obj2,
|
1020
|
-
# ... more subpopulations ...
|
1021
|
-
}
|
1022
|
-
genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1006
|
+
>>> ct_subs = {
|
1007
|
+
... "subpop1": anndata_obj1,
|
1008
|
+
... "subpop2": anndata_obj2,
|
1009
|
+
... # ... more subpopulations ...
|
1010
|
+
... }
|
1011
|
+
>>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1023
1012
|
"""
|
1024
1013
|
genes = {}
|
1025
1014
|
for ct in ct_subs.keys():
|
1026
1015
|
mini = ct_subs[ct]
|
1027
|
-
mini.obs[mcp]
|
1028
1016
|
mini.obs["extrema"] = pd.qcut(
|
1029
1017
|
mini.obs[mcp],
|
1030
1018
|
[0, 0 + fraction, 1 - fraction, 1.0],
|
@@ -1034,6 +1022,7 @@ class Dialogue:
|
|
1034
1022
|
mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
|
1035
1023
|
)
|
1036
1024
|
genes[ct] = mini # .uns['rank_genes_groups']
|
1025
|
+
|
1037
1026
|
return genes
|
1038
1027
|
|
1039
1028
|
def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
|
@@ -1064,7 +1053,7 @@ class Dialogue:
|
|
1064
1053
|
>>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
|
1065
1054
|
"""
|
1066
1055
|
rank_dfs: dict[str, dict[Any, Any]] = {}
|
1067
|
-
|
1056
|
+
ct_sub = next(iter(ct_subs.values()))
|
1068
1057
|
mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
|
1069
1058
|
|
1070
1059
|
for mcp in mcps:
|
@@ -1072,4 +1061,123 @@ class Dialogue:
|
|
1072
1061
|
ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
|
1073
1062
|
for celltype in ct_ranked.keys():
|
1074
1063
|
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
|
1064
|
+
|
1075
1065
|
return rank_dfs
|
1066
|
+
|
1067
|
+
def plot_split_violins(
|
1068
|
+
self,
|
1069
|
+
adata: AnnData,
|
1070
|
+
split_key: str,
|
1071
|
+
celltype_key: str,
|
1072
|
+
split_which: tuple[str, str] = None,
|
1073
|
+
mcp: str = "mcp_0",
|
1074
|
+
return_fig: bool | None = None,
|
1075
|
+
ax: Axes | None = None,
|
1076
|
+
save: bool | str | None = None,
|
1077
|
+
show: bool | None = None,
|
1078
|
+
) -> Axes | Figure | None:
|
1079
|
+
"""Plots split violin plots for a given MCP and split variable.
|
1080
|
+
|
1081
|
+
Any cells with a value for split_key not in split_which are removed from the plot.
|
1082
|
+
|
1083
|
+
Args:
|
1084
|
+
adata: Annotated data object.
|
1085
|
+
split_key: Variable in adata.obs used to split the data.
|
1086
|
+
celltype_key: Key for cell type annotations.
|
1087
|
+
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
1088
|
+
mcp: Key for MCP data. Defaults to "mcp_0".
|
1089
|
+
|
1090
|
+
Returns:
|
1091
|
+
A :class:`~matplotlib.axes.Axes` object
|
1092
|
+
|
1093
|
+
Examples:
|
1094
|
+
>>> import pertpy as pt
|
1095
|
+
>>> import scanpy as sc
|
1096
|
+
>>> adata = pt.dt.dialogue_example()
|
1097
|
+
>>> sc.pp.pca(adata)
|
1098
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1099
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1100
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1101
|
+
>>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
|
1102
|
+
|
1103
|
+
Preview:
|
1104
|
+
.. image:: /_static/docstring_previews/dialogue_violin.png
|
1105
|
+
"""
|
1106
|
+
df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
|
1107
|
+
if split_which is None:
|
1108
|
+
split_which = df[split_key].unique()
|
1109
|
+
df = df[df[split_key].isin(split_which)]
|
1110
|
+
df[split_key] = df[split_key].cat.remove_unused_categories()
|
1111
|
+
|
1112
|
+
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1113
|
+
|
1114
|
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1115
|
+
|
1116
|
+
if save:
|
1117
|
+
plt.savefig(save, bbox_inches="tight")
|
1118
|
+
if show:
|
1119
|
+
plt.show()
|
1120
|
+
if return_fig:
|
1121
|
+
return plt.gcf()
|
1122
|
+
if not (show or save):
|
1123
|
+
return ax
|
1124
|
+
return None
|
1125
|
+
|
1126
|
+
def plot_pairplot(
|
1127
|
+
self,
|
1128
|
+
adata: AnnData,
|
1129
|
+
celltype_key: str,
|
1130
|
+
color: str,
|
1131
|
+
sample_id: str,
|
1132
|
+
mcp: str = "mcp_0",
|
1133
|
+
return_fig: bool | None = None,
|
1134
|
+
show: bool | None = None,
|
1135
|
+
save: bool | str | None = None,
|
1136
|
+
) -> PairGrid | Figure | None:
|
1137
|
+
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
1138
|
+
|
1139
|
+
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
1140
|
+
then creates a pairplot to visualize the relationships between these mean MCP values.
|
1141
|
+
|
1142
|
+
Args:
|
1143
|
+
adata: Annotated data object.
|
1144
|
+
celltype_key: Key in `adata.obs` containing cell type annotations.
|
1145
|
+
color: Key in `adata.obs` for color annotations. This parameter is used as the hue
|
1146
|
+
sample_id: Key in `adata.obs` for the sample annotations.
|
1147
|
+
mcp: Key in `adata.obs` for MCP feature values. Defaults to `"mcp_0"`.
|
1148
|
+
|
1149
|
+
Returns:
|
1150
|
+
Seaborn Pairgrid object.
|
1151
|
+
|
1152
|
+
Examples:
|
1153
|
+
>>> import pertpy as pt
|
1154
|
+
>>> import scanpy as sc
|
1155
|
+
>>> adata = pt.dt.dialogue_example()
|
1156
|
+
>>> sc.pp.pca(adata)
|
1157
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1158
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1159
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1160
|
+
>>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
|
1161
|
+
|
1162
|
+
Preview:
|
1163
|
+
.. image:: /_static/docstring_previews/dialogue_pairplot.png
|
1164
|
+
"""
|
1165
|
+
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
|
1166
|
+
mean_mcps = mean_mcps.reset_index()
|
1167
|
+
mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
|
1168
|
+
|
1169
|
+
aggstats = adata.obs.groupby([sample_id])[color].describe()
|
1170
|
+
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
1171
|
+
aggstats[color] = aggstats["top"]
|
1172
|
+
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1173
|
+
ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1174
|
+
|
1175
|
+
if save:
|
1176
|
+
plt.savefig(save, bbox_inches="tight")
|
1177
|
+
if show:
|
1178
|
+
plt.show()
|
1179
|
+
if return_fig:
|
1180
|
+
return plt.gcf()
|
1181
|
+
if not (show or save):
|
1182
|
+
return ax
|
1183
|
+
return None
|