pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import itertools
|
4
4
|
from collections import defaultdict
|
5
|
-
from typing import Any, Literal
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal
|
6
6
|
|
7
7
|
import anndata as ad
|
8
|
+
import matplotlib.pyplot as plt
|
8
9
|
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
import scanpy as sc
|
11
|
-
import
|
12
|
+
import seaborn as sns
|
12
13
|
import statsmodels.formula.api as smf
|
13
14
|
import statsmodels.stats.multitest as ssm
|
14
15
|
from anndata import AnnData
|
@@ -19,10 +20,15 @@ from rich.live import Live
|
|
19
20
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
20
21
|
from scipy import stats
|
21
22
|
from scipy.optimize import nnls
|
23
|
+
from seaborn import PairGrid
|
22
24
|
from sklearn.linear_model import LinearRegression
|
23
25
|
from sparsecca import lp_pmd, multicca_permute, multicca_pmd
|
24
26
|
from statsmodels.sandbox.stats.multicomp import multipletests
|
25
27
|
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from matplotlib.axes import Axes
|
30
|
+
from matplotlib.figure import Figure
|
31
|
+
|
26
32
|
|
27
33
|
class Dialogue:
|
28
34
|
"""Python implementation of DIALOGUE"""
|
@@ -53,8 +59,6 @@ class Dialogue:
|
|
53
59
|
|
54
60
|
Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
|
55
61
|
|
56
|
-
# TODO: Replace with decoupler's implementation
|
57
|
-
|
58
62
|
Args:
|
59
63
|
groupby: The key to groupby for pseudobulks
|
60
64
|
strategy: The pseudobulking strategy. One of "median" or "mean"
|
@@ -62,14 +66,15 @@ class Dialogue:
|
|
62
66
|
Returns:
|
63
67
|
A Pandas DataFrame of pseudobulk counts
|
64
68
|
"""
|
69
|
+
# TODO: Replace with decoupler's implementation
|
65
70
|
pseudobulk = {"Genes": adata.var_names.values}
|
66
71
|
|
67
72
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
68
73
|
temp = adata.obs.loc[:, groupby] == category
|
69
74
|
if strategy == "median":
|
70
|
-
pseudobulk[category] = adata[temp].X.median(axis=0)
|
75
|
+
pseudobulk[category] = adata[temp].X.median(axis=0)
|
71
76
|
elif strategy == "mean":
|
72
|
-
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
77
|
+
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
73
78
|
|
74
79
|
pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
|
75
80
|
|
@@ -101,8 +106,6 @@ class Dialogue:
|
|
101
106
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
102
107
|
"""Row-wise mean center and scale by the standard deviation.
|
103
108
|
|
104
|
-
TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
105
|
-
|
106
109
|
Args:
|
107
110
|
pseudobulks: The pseudobulk PCA components.
|
108
111
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
@@ -110,9 +113,9 @@ class Dialogue:
|
|
110
113
|
Returns:
|
111
114
|
The scaled count matrix.
|
112
115
|
"""
|
116
|
+
# TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
113
117
|
# DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
|
114
118
|
# However, performing this scaling _does_ increase overall correlation of the end result
|
115
|
-
# WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
|
116
119
|
if normalize:
|
117
120
|
return pseudobulks.to_numpy()
|
118
121
|
else:
|
@@ -313,13 +316,13 @@ class Dialogue:
|
|
313
316
|
def _apply_HLM_per_MCP_for_one_pair(
|
314
317
|
self,
|
315
318
|
mcp_name: str,
|
316
|
-
scores_df:
|
319
|
+
scores_df: pd.DataFrame,
|
317
320
|
ct_data: AnnData,
|
318
321
|
tme: pd.DataFrame,
|
319
322
|
sig: dict,
|
320
323
|
n_counts: str,
|
321
324
|
formula: str,
|
322
|
-
confounder: str,
|
325
|
+
confounder: str | None,
|
323
326
|
) -> tuple[pd.DataFrame, dict[str, Any]]:
|
324
327
|
"""Applies hierarchical modeling for a single MCP.
|
325
328
|
|
@@ -340,7 +343,7 @@ class Dialogue:
|
|
340
343
|
"""
|
341
344
|
HLM_result = self._mixed_effects(
|
342
345
|
scores=scores_df[[mcp_name]],
|
343
|
-
x_labels=ct_data.obs[[n_counts, confounder]],
|
346
|
+
x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
|
344
347
|
tme=tme,
|
345
348
|
genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
|
346
349
|
formula=formula,
|
@@ -367,7 +370,7 @@ class Dialogue:
|
|
367
370
|
return np.array(resid)
|
368
371
|
|
369
372
|
def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
|
370
|
-
"""Solves non-negative least
|
373
|
+
"""Solves non-negative least-squares separately for different feature categories.
|
371
374
|
|
372
375
|
Mimics DLG.iterative.nnls.
|
373
376
|
Variables are notated according to:
|
@@ -398,7 +401,7 @@ class Dialogue:
|
|
398
401
|
|
399
402
|
x_final = np.zeros(A_orig.shape[0])
|
400
403
|
Ax = np.zeros(A_orig.shape[1])
|
401
|
-
for _, mask in zip(sig_ranks, masks):
|
404
|
+
for _, mask in zip(sig_ranks, masks, strict=False):
|
402
405
|
A = A_orig[mask].T
|
403
406
|
coef_nnls, _ = nnls(A, y, maxiter=n_iter)
|
404
407
|
y = y - A @ coef_nnls # residuals
|
@@ -516,8 +519,8 @@ class Dialogue:
|
|
516
519
|
# TODO: probably format the up and down within get_top_elements
|
517
520
|
cca_sig: dict[str, Any] = defaultdict(dict)
|
518
521
|
for i in range(0, int(len(cca_sig_unformatted) / 2)):
|
519
|
-
cca_sig[f"MCP{i
|
520
|
-
cca_sig[f"MCP{i
|
522
|
+
cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
|
523
|
+
cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
|
521
524
|
|
522
525
|
cca_sig = dict(cca_sig)
|
523
526
|
cca_sig_results[ct] = cca_sig
|
@@ -555,7 +558,7 @@ class Dialogue:
|
|
555
558
|
|
556
559
|
return cca_sig_results, new_mcp_scores
|
557
560
|
|
558
|
-
def
|
561
|
+
def _load(
|
559
562
|
self,
|
560
563
|
adata: AnnData,
|
561
564
|
ct_order: list[str],
|
@@ -574,16 +577,6 @@ class Dialogue:
|
|
574
577
|
|
575
578
|
Returns:
|
576
579
|
A celltype_label:array dictionary.
|
577
|
-
|
578
|
-
Examples:
|
579
|
-
>>> import pertpy as pt
|
580
|
-
>>> import scanpy as sc
|
581
|
-
>>> adata = pt.dt.dialogue_example()
|
582
|
-
>>> sc.pp.pca(adata)
|
583
|
-
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
584
|
-
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
585
|
-
>>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
|
586
|
-
>>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
|
587
580
|
"""
|
588
581
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
589
582
|
fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
|
@@ -631,19 +624,19 @@ class Dialogue:
|
|
631
624
|
>>> import scanpy as sc
|
632
625
|
>>> adata = pt.dt.dialogue_example()
|
633
626
|
>>> sc.pp.pca(adata)
|
634
|
-
>>> dl = pt.tl.Dialogue(
|
635
|
-
|
627
|
+
>>> dl = pt.tl.Dialogue(
|
628
|
+
... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
|
629
|
+
... )
|
636
630
|
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
637
631
|
"""
|
638
|
-
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
639
|
-
# it is important here that to obtain the same result as in R, we pass the matrices in
|
640
|
-
# in the same order.
|
632
|
+
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
633
|
+
# As such, to obtain the same result as in R, it is important that we pass the matrices in the same order.
|
641
634
|
if ct_order is not None:
|
642
635
|
cell_types = ct_order
|
643
636
|
else:
|
644
637
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
645
638
|
|
646
|
-
mcca_in, ct_subs = self.
|
639
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
|
647
640
|
|
648
641
|
n_samples = mcca_in[0].shape[1]
|
649
642
|
if penalties is None:
|
@@ -685,7 +678,7 @@ class Dialogue:
|
|
685
678
|
ct_subs: dict,
|
686
679
|
mcp_scores: dict,
|
687
680
|
ws_dict: dict,
|
688
|
-
confounder: str,
|
681
|
+
confounder: str | None,
|
689
682
|
formula: str = None,
|
690
683
|
):
|
691
684
|
"""Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
|
@@ -700,7 +693,6 @@ class Dialogue:
|
|
700
693
|
A Pandas DataFrame containing:
|
701
694
|
- for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
|
702
695
|
- merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
|
703
|
-
TODO: Describe both returns
|
704
696
|
|
705
697
|
Examples:
|
706
698
|
>>> import pertpy as pt
|
@@ -713,7 +705,9 @@ class Dialogue:
|
|
713
705
|
>>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
|
714
706
|
confounder="gender")
|
715
707
|
"""
|
716
|
-
#
|
708
|
+
# TODO: Describe the returns of the function better
|
709
|
+
|
710
|
+
# all possible pairs of cell types without pairing same cell type
|
717
711
|
cell_types = list(ct_subs.keys())
|
718
712
|
pairs = list(itertools.combinations(cell_types, 2))
|
719
713
|
|
@@ -721,9 +715,9 @@ class Dialogue:
|
|
721
715
|
formula = f"y ~ x + {self.n_counts_key}"
|
722
716
|
|
723
717
|
# Hierarchical modeling expects DataFrames
|
724
|
-
mcp_cell_types = {f"MCP{i
|
718
|
+
mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
|
725
719
|
mcp_scores_df = {
|
726
|
-
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
|
720
|
+
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
|
727
721
|
for ct, v in mcp_scores.items()
|
728
722
|
}
|
729
723
|
|
@@ -805,7 +799,7 @@ class Dialogue:
|
|
805
799
|
for mcp in mcps:
|
806
800
|
mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
|
807
801
|
|
808
|
-
# TODO Check
|
802
|
+
# TODO: Check whether the genes in result["sig_genes_1"] are different and if so note that somewhere and explain why
|
809
803
|
result = {}
|
810
804
|
result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
|
811
805
|
mcp_name=mcp,
|
@@ -875,22 +869,19 @@ class Dialogue:
|
|
875
869
|
sample_label = self.sample_id
|
876
870
|
n_mcps = self.n_mcps
|
877
871
|
|
878
|
-
# create conditions_compare if not supplied
|
879
872
|
if conditions_compare is None:
|
880
|
-
conditions_compare = list(adata.obs[
|
873
|
+
conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
|
881
874
|
if len(conditions_compare) != 2:
|
882
875
|
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
|
883
876
|
|
884
|
-
# create data frames to store results
|
885
877
|
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
886
878
|
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
887
879
|
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
888
880
|
|
889
881
|
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
|
890
882
|
for celltype in adata.obs[celltype_label].unique():
|
891
|
-
# subset data to cell type
|
892
883
|
df = adata.obs[adata.obs[celltype_label] == celltype]
|
893
|
-
|
884
|
+
|
894
885
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
895
886
|
mns = df.groupby(sample_label)[mcpnum].mean()
|
896
887
|
mns = pd.concat([mns, response], axis=1)
|
@@ -900,11 +891,10 @@ class Dialogue:
|
|
900
891
|
)
|
901
892
|
pvals.loc[celltype, mcpnum] = res[1]
|
902
893
|
tstats.loc[celltype, mcpnum] = res[0]
|
903
|
-
# return(res)
|
904
894
|
|
905
|
-
# benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
|
906
895
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
907
896
|
pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
|
897
|
+
|
908
898
|
return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
|
909
899
|
|
910
900
|
def get_mlm_mcp_genes(
|
@@ -921,7 +911,7 @@ class Dialogue:
|
|
921
911
|
celltype: Cell type of interest.
|
922
912
|
results: dl.MultilevelModeling result object.
|
923
913
|
MCP: MCP key of the result object.
|
924
|
-
|
914
|
+
threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
|
925
915
|
Defaults to 0.70.
|
926
916
|
focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
|
927
917
|
Defaults to None.
|
@@ -945,7 +935,6 @@ class Dialogue:
|
|
945
935
|
# REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
|
946
936
|
if MCP.startswith("mcp_"):
|
947
937
|
MCP = MCP.replace("mcp_", "MCP")
|
948
|
-
# convert from MCPx to MCPx+1
|
949
938
|
MCP = "MCP" + str(int(MCP[3:]) - 1)
|
950
939
|
|
951
940
|
# Extract all comparison keys from the results object
|
@@ -1014,17 +1003,16 @@ class Dialogue:
|
|
1014
1003
|
objects containing the results of gene ranking analysis.
|
1015
1004
|
|
1016
1005
|
Examples:
|
1017
|
-
ct_subs = {
|
1018
|
-
"subpop1": anndata_obj1,
|
1019
|
-
"subpop2": anndata_obj2,
|
1020
|
-
# ... more subpopulations ...
|
1021
|
-
}
|
1022
|
-
genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1006
|
+
>>> ct_subs = {
|
1007
|
+
... "subpop1": anndata_obj1,
|
1008
|
+
... "subpop2": anndata_obj2,
|
1009
|
+
... # ... more subpopulations ...
|
1010
|
+
... }
|
1011
|
+
>>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1023
1012
|
"""
|
1024
1013
|
genes = {}
|
1025
1014
|
for ct in ct_subs.keys():
|
1026
1015
|
mini = ct_subs[ct]
|
1027
|
-
mini.obs[mcp]
|
1028
1016
|
mini.obs["extrema"] = pd.qcut(
|
1029
1017
|
mini.obs[mcp],
|
1030
1018
|
[0, 0 + fraction, 1 - fraction, 1.0],
|
@@ -1034,6 +1022,7 @@ class Dialogue:
|
|
1034
1022
|
mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
|
1035
1023
|
)
|
1036
1024
|
genes[ct] = mini # .uns['rank_genes_groups']
|
1025
|
+
|
1037
1026
|
return genes
|
1038
1027
|
|
1039
1028
|
def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
|
@@ -1064,7 +1053,7 @@ class Dialogue:
|
|
1064
1053
|
>>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
|
1065
1054
|
"""
|
1066
1055
|
rank_dfs: dict[str, dict[Any, Any]] = {}
|
1067
|
-
|
1056
|
+
ct_sub = next(iter(ct_subs.values()))
|
1068
1057
|
mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
|
1069
1058
|
|
1070
1059
|
for mcp in mcps:
|
@@ -1072,4 +1061,123 @@ class Dialogue:
|
|
1072
1061
|
ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
|
1073
1062
|
for celltype in ct_ranked.keys():
|
1074
1063
|
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
|
1064
|
+
|
1075
1065
|
return rank_dfs
|
1066
|
+
|
1067
|
+
def plot_split_violins(
|
1068
|
+
self,
|
1069
|
+
adata: AnnData,
|
1070
|
+
split_key: str,
|
1071
|
+
celltype_key: str,
|
1072
|
+
split_which: tuple[str, str] = None,
|
1073
|
+
mcp: str = "mcp_0",
|
1074
|
+
return_fig: bool | None = None,
|
1075
|
+
ax: Axes | None = None,
|
1076
|
+
save: bool | str | None = None,
|
1077
|
+
show: bool | None = None,
|
1078
|
+
) -> Axes | Figure | None:
|
1079
|
+
"""Plots split violin plots for a given MCP and split variable.
|
1080
|
+
|
1081
|
+
Any cells with a value for split_key not in split_which are removed from the plot.
|
1082
|
+
|
1083
|
+
Args:
|
1084
|
+
adata: Annotated data object.
|
1085
|
+
split_key: Variable in adata.obs used to split the data.
|
1086
|
+
celltype_key: Key for cell type annotations.
|
1087
|
+
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
1088
|
+
mcp: Key for MCP data. Defaults to "mcp_0".
|
1089
|
+
|
1090
|
+
Returns:
|
1091
|
+
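The current :class:`~matplotlib.figure.Figure` if `return_fig` is set; otherwise the :class:`~matplotlib.axes.Axes` object when neither `show` nor `save` is set, else None.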
|
1092
|
+
|
1093
|
+
Examples:
|
1094
|
+
>>> import pertpy as pt
|
1095
|
+
>>> import scanpy as sc
|
1096
|
+
>>> adata = pt.dt.dialogue_example()
|
1097
|
+
>>> sc.pp.pca(adata)
|
1098
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1099
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1100
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1101
|
+
>>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
|
1102
|
+
|
1103
|
+
Preview:
|
1104
|
+
.. image:: /_static/docstring_previews/dialogue_violin.png
|
1105
|
+
"""
|
1106
|
+
df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
|
1107
|
+
if split_which is None:
|
1108
|
+
split_which = df[split_key].unique()
|
1109
|
+
df = df[df[split_key].isin(split_which)]
|
1110
|
+
df[split_key] = df[split_key].cat.remove_unused_categories()
|
1111
|
+
|
1112
|
+
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1113
|
+
|
1114
|
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1115
|
+
|
1116
|
+
if save:
|
1117
|
+
plt.savefig(save, bbox_inches="tight")
|
1118
|
+
if show:
|
1119
|
+
plt.show()
|
1120
|
+
if return_fig:
|
1121
|
+
return plt.gcf()
|
1122
|
+
if not (show or save):
|
1123
|
+
return ax
|
1124
|
+
return None
|
1125
|
+
|
1126
|
+
def plot_pairplot(
|
1127
|
+
self,
|
1128
|
+
adata: AnnData,
|
1129
|
+
celltype_key: str,
|
1130
|
+
color: str,
|
1131
|
+
sample_id: str,
|
1132
|
+
mcp: str = "mcp_0",
|
1133
|
+
return_fig: bool | None = None,
|
1134
|
+
show: bool | None = None,
|
1135
|
+
save: bool | str | None = None,
|
1136
|
+
) -> PairGrid | Figure | None:
|
1137
|
+
"""Generate a pairplot visualization for multicellular program (MCP) data.
|
1138
|
+
|
1139
|
+
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
1140
|
+
then creates a pairplot to visualize the relationships between these mean MCP values.
|
1141
|
+
|
1142
|
+
Args:
|
1143
|
+
adata: Annotated data object.
|
1144
|
+
celltype_key: Key in `adata.obs` containing cell type annotations.
|
1145
|
+
color: Key in `adata.obs` for color annotations. This parameter is used as the hue of the pairplot.
|
1146
|
+
sample_id: Key in `adata.obs` for the sample annotations.
|
1147
|
+
mcp: Key in `adata.obs` for MCP feature values. Defaults to `"mcp_0"`.
|
1148
|
+
|
1149
|
+
Returns:
|
1150
|
+
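The current :class:`~matplotlib.figure.Figure` if `return_fig` is set; otherwise the seaborn :class:`~seaborn.PairGrid` object when neither `show` nor `save` is set, else None.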
|
1151
|
+
|
1152
|
+
Examples:
|
1153
|
+
>>> import pertpy as pt
|
1154
|
+
>>> import scanpy as sc
|
1155
|
+
>>> adata = pt.dt.dialogue_example()
|
1156
|
+
>>> sc.pp.pca(adata)
|
1157
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1158
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1159
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1160
|
+
>>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
|
1161
|
+
|
1162
|
+
Preview:
|
1163
|
+
.. image:: /_static/docstring_previews/dialogue_pairplot.png
|
1164
|
+
"""
|
1165
|
+
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
|
1166
|
+
mean_mcps = mean_mcps.reset_index()
|
1167
|
+
mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
|
1168
|
+
|
1169
|
+
aggstats = adata.obs.groupby([sample_id])[color].describe()
|
1170
|
+
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
1171
|
+
aggstats[color] = aggstats["top"]
|
1172
|
+
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1173
|
+
ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1174
|
+
|
1175
|
+
if save:
|
1176
|
+
plt.savefig(save, bbox_inches="tight")
|
1177
|
+
if show:
|
1178
|
+
plt.show()
|
1179
|
+
if return_fig:
|
1180
|
+
return plt.gcf()
|
1181
|
+
if not (show or save):
|
1182
|
+
return ax
|
1183
|
+
return None
|