pertpy-0.6.0-py3-none-any.whl → pertpy-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -2,27 +2,33 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import itertools
|
4
4
|
from collections import defaultdict
|
5
|
-
from typing import Any, Literal
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal
|
6
6
|
|
7
7
|
import anndata as ad
|
8
|
+
import matplotlib.pyplot as plt
|
8
9
|
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
import scanpy as sc
|
11
|
-
import
|
12
|
+
import seaborn as sns
|
12
13
|
import statsmodels.formula.api as smf
|
13
14
|
import statsmodels.stats.multitest as ssm
|
14
15
|
from anndata import AnnData
|
16
|
+
from lamin_utils import logger
|
15
17
|
from pandas import DataFrame
|
16
|
-
from rich import print
|
17
18
|
from rich.console import Group
|
18
19
|
from rich.live import Live
|
19
20
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
20
21
|
from scipy import stats
|
21
22
|
from scipy.optimize import nnls
|
23
|
+
from seaborn import PairGrid
|
22
24
|
from sklearn.linear_model import LinearRegression
|
23
25
|
from sparsecca import lp_pmd, multicca_permute, multicca_pmd
|
24
26
|
from statsmodels.sandbox.stats.multicomp import multipletests
|
25
27
|
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from matplotlib.axes import Axes
|
30
|
+
from matplotlib.figure import Figure
|
31
|
+
|
26
32
|
|
27
33
|
class Dialogue:
|
28
34
|
"""Python implementation of DIALOGUE"""
|
@@ -53,8 +59,6 @@ class Dialogue:
|
|
53
59
|
|
54
60
|
Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
|
55
61
|
|
56
|
-
# TODO: Replace with decoupler's implementation
|
57
|
-
|
58
62
|
Args:
|
59
63
|
groupby: The key to groupby for pseudobulks
|
60
64
|
strategy: The pseudobulking strategy. One of "median" or "mean"
|
@@ -62,14 +66,15 @@ class Dialogue:
|
|
62
66
|
Returns:
|
63
67
|
A Pandas DataFrame of pseudobulk counts
|
64
68
|
"""
|
69
|
+
# TODO: Replace with decoupler's implementation
|
65
70
|
pseudobulk = {"Genes": adata.var_names.values}
|
66
71
|
|
67
72
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
68
73
|
temp = adata.obs.loc[:, groupby] == category
|
69
74
|
if strategy == "median":
|
70
|
-
pseudobulk[category] = adata[temp].X.median(axis=0)
|
75
|
+
pseudobulk[category] = adata[temp].X.median(axis=0)
|
71
76
|
elif strategy == "mean":
|
72
|
-
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
77
|
+
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
73
78
|
|
74
79
|
pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
|
75
80
|
|
@@ -101,8 +106,6 @@ class Dialogue:
|
|
101
106
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
102
107
|
"""Row-wise mean center and scale by the standard deviation.
|
103
108
|
|
104
|
-
TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
105
|
-
|
106
109
|
Args:
|
107
110
|
pseudobulks: The pseudobulk PCA components.
|
108
111
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
@@ -110,9 +113,9 @@ class Dialogue:
|
|
110
113
|
Returns:
|
111
114
|
The scaled count matrix.
|
112
115
|
"""
|
116
|
+
# TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
113
117
|
# DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
|
114
118
|
# However, performing this scaling _does_ increase overall correlation of the end result
|
115
|
-
# WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
|
116
119
|
if normalize:
|
117
120
|
return pseudobulks.to_numpy()
|
118
121
|
else:
|
@@ -288,7 +291,7 @@ class Dialogue:
|
|
288
291
|
mcp_name: Name of mcp which was used for calculation of column value.
|
289
292
|
max_length: Value needed to later decide at what index the threshold value should be extracted from column.
|
290
293
|
min_threshold: Minimal threshold to select final scores by if it is smaller than calculated threshold.
|
291
|
-
index: Column index to use eto calculate the significant genes.
|
294
|
+
index: Column index to use eto calculate the significant genes.
|
292
295
|
|
293
296
|
Returns:
|
294
297
|
According to the values in a df column (default: zscore) the significant up and downregulated gene names
|
@@ -313,13 +316,13 @@ class Dialogue:
|
|
313
316
|
def _apply_HLM_per_MCP_for_one_pair(
|
314
317
|
self,
|
315
318
|
mcp_name: str,
|
316
|
-
scores_df:
|
319
|
+
scores_df: pd.DataFrame,
|
317
320
|
ct_data: AnnData,
|
318
321
|
tme: pd.DataFrame,
|
319
322
|
sig: dict,
|
320
323
|
n_counts: str,
|
321
324
|
formula: str,
|
322
|
-
confounder: str,
|
325
|
+
confounder: str | None,
|
323
326
|
) -> tuple[pd.DataFrame, dict[str, Any]]:
|
324
327
|
"""Applies hierarchical modeling for a single MCP.
|
325
328
|
|
@@ -340,7 +343,7 @@ class Dialogue:
|
|
340
343
|
"""
|
341
344
|
HLM_result = self._mixed_effects(
|
342
345
|
scores=scores_df[[mcp_name]],
|
343
|
-
x_labels=ct_data.obs[[n_counts, confounder]],
|
346
|
+
x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
|
344
347
|
tme=tme,
|
345
348
|
genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
|
346
349
|
formula=formula,
|
@@ -367,19 +370,13 @@ class Dialogue:
|
|
367
370
|
return np.array(resid)
|
368
371
|
|
369
372
|
def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
|
370
|
-
"""Solves non-negative least
|
373
|
+
"""Solves non-negative least-squares separately for different feature categories.
|
371
374
|
|
372
375
|
Mimics DLG.iterative.nnls.
|
373
376
|
Variables are notated according to:
|
374
377
|
|
375
378
|
`argmin|Ax - y|`
|
376
379
|
|
377
|
-
Args:
|
378
|
-
A_orig:
|
379
|
-
y_orig:
|
380
|
-
feature_ranks:
|
381
|
-
n_iter: Passed to scipy.optimize.nnls. Defaults to 1000.
|
382
|
-
|
383
380
|
Returns:
|
384
381
|
Returns the aggregated coefficients from nnls.
|
385
382
|
"""
|
@@ -398,7 +395,7 @@ class Dialogue:
|
|
398
395
|
|
399
396
|
x_final = np.zeros(A_orig.shape[0])
|
400
397
|
Ax = np.zeros(A_orig.shape[1])
|
401
|
-
for _, mask in zip(sig_ranks, masks):
|
398
|
+
for _, mask in zip(sig_ranks, masks, strict=False):
|
402
399
|
A = A_orig[mask].T
|
403
400
|
coef_nnls, _ = nnls(A, y, maxiter=n_iter)
|
404
401
|
y = y - A @ coef_nnls # residuals
|
@@ -516,8 +513,8 @@ class Dialogue:
|
|
516
513
|
# TODO: probably format the up and down within get_top_elements
|
517
514
|
cca_sig: dict[str, Any] = defaultdict(dict)
|
518
515
|
for i in range(0, int(len(cca_sig_unformatted) / 2)):
|
519
|
-
cca_sig[f"MCP{i
|
520
|
-
cca_sig[f"MCP{i
|
516
|
+
cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
|
517
|
+
cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
|
521
518
|
|
522
519
|
cca_sig = dict(cca_sig)
|
523
520
|
cca_sig_results[ct] = cca_sig
|
@@ -555,7 +552,7 @@ class Dialogue:
|
|
555
552
|
|
556
553
|
return cca_sig_results, new_mcp_scores
|
557
554
|
|
558
|
-
def
|
555
|
+
def _load(
|
559
556
|
self,
|
560
557
|
adata: AnnData,
|
561
558
|
ct_order: list[str],
|
@@ -569,21 +566,11 @@ class Dialogue:
|
|
569
566
|
Args:
|
570
567
|
adata: AnnData object generate celltype objects for
|
571
568
|
ct_order: The order of cell types
|
572
|
-
agg_pca: Whether to aggregate pseudobulks with PCA or not.
|
573
|
-
normalize: Whether to mimic DIALOGUE behavior or not.
|
569
|
+
agg_pca: Whether to aggregate pseudobulks with PCA or not.
|
570
|
+
normalize: Whether to mimic DIALOGUE behavior or not.
|
574
571
|
|
575
572
|
Returns:
|
576
573
|
A celltype_label:array dictionary.
|
577
|
-
|
578
|
-
Examples:
|
579
|
-
>>> import pertpy as pt
|
580
|
-
>>> import scanpy as sc
|
581
|
-
>>> adata = pt.dt.dialogue_example()
|
582
|
-
>>> sc.pp.pca(adata)
|
583
|
-
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
584
|
-
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
585
|
-
>>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
|
586
|
-
>>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
|
587
574
|
"""
|
588
575
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
589
576
|
fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
|
@@ -620,7 +607,6 @@ class Dialogue:
|
|
620
607
|
agg_pca: Whether to calculate cell-averaged PCA components.
|
621
608
|
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
|
622
609
|
For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
|
623
|
-
Defaults to 'bs'.
|
624
610
|
normalize: Whether to mimic DIALOGUE as close as possible
|
625
611
|
|
626
612
|
Returns:
|
@@ -631,25 +617,31 @@ class Dialogue:
|
|
631
617
|
>>> import scanpy as sc
|
632
618
|
>>> adata = pt.dt.dialogue_example()
|
633
619
|
>>> sc.pp.pca(adata)
|
634
|
-
>>> dl = pt.tl.Dialogue(
|
635
|
-
|
620
|
+
>>> dl = pt.tl.Dialogue(
|
621
|
+
... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
|
622
|
+
... )
|
636
623
|
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
637
624
|
"""
|
638
|
-
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
639
|
-
# it is important here that to obtain the same result as in R, we pass the matrices in
|
640
|
-
# in the same order.
|
625
|
+
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
626
|
+
# As such, it is important here that to obtain the same result as in R, we pass the matrices in the same order.
|
641
627
|
if ct_order is not None:
|
642
628
|
cell_types = ct_order
|
643
629
|
else:
|
644
630
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
645
631
|
|
646
|
-
mcca_in, ct_subs = self.
|
632
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
|
647
633
|
|
648
634
|
n_samples = mcca_in[0].shape[1]
|
649
635
|
if penalties is None:
|
650
|
-
|
651
|
-
|
652
|
-
|
636
|
+
try:
|
637
|
+
penalties = multicca_permute(
|
638
|
+
mcca_in, penalties=np.sqrt(n_samples) / 2, nperms=10, niter=50, standardize=True
|
639
|
+
)["bestpenalties"]
|
640
|
+
except ValueError as e:
|
641
|
+
if "matmul: input operand 1 has a mismatch in its core dimension" in str(e):
|
642
|
+
raise ValueError("Please ensure that every cell type is represented in every sample.") from e
|
643
|
+
else:
|
644
|
+
raise
|
653
645
|
else:
|
654
646
|
penalties = penalties
|
655
647
|
|
@@ -685,7 +677,7 @@ class Dialogue:
|
|
685
677
|
ct_subs: dict,
|
686
678
|
mcp_scores: dict,
|
687
679
|
ws_dict: dict,
|
688
|
-
confounder: str,
|
680
|
+
confounder: str | None,
|
689
681
|
formula: str = None,
|
690
682
|
):
|
691
683
|
"""Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
|
@@ -700,7 +692,6 @@ class Dialogue:
|
|
700
692
|
A Pandas DataFrame containing:
|
701
693
|
- for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
|
702
694
|
- merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
|
703
|
-
TODO: Describe both returns
|
704
695
|
|
705
696
|
Examples:
|
706
697
|
>>> import pertpy as pt
|
@@ -713,7 +704,9 @@ class Dialogue:
|
|
713
704
|
>>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
|
714
705
|
confounder="gender")
|
715
706
|
"""
|
716
|
-
#
|
707
|
+
# TODO the returns of the function better
|
708
|
+
|
709
|
+
# all possible pairs of cell types without pairing same cell type
|
717
710
|
cell_types = list(ct_subs.keys())
|
718
711
|
pairs = list(itertools.combinations(cell_types, 2))
|
719
712
|
|
@@ -721,9 +714,9 @@ class Dialogue:
|
|
721
714
|
formula = f"y ~ x + {self.n_counts_key}"
|
722
715
|
|
723
716
|
# Hierarchical modeling expects DataFrames
|
724
|
-
mcp_cell_types = {f"MCP{i
|
717
|
+
mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
|
725
718
|
mcp_scores_df = {
|
726
|
-
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
|
719
|
+
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
|
727
720
|
for ct, v in mcp_scores.items()
|
728
721
|
}
|
729
722
|
|
@@ -762,10 +755,10 @@ class Dialogue:
|
|
762
755
|
mcps.append(mcp)
|
763
756
|
|
764
757
|
if len(mcps) == 0:
|
765
|
-
|
758
|
+
logger.warning(f"No shared MCPs between {cell_type_1} and {cell_type_2}.")
|
766
759
|
continue
|
767
760
|
|
768
|
-
|
761
|
+
logger.info(f"{len(mcps)} MCPs identified for {cell_type_1} and {cell_type_2}.")
|
769
762
|
|
770
763
|
new_mcp_scores: dict[Any, list[Any]]
|
771
764
|
cca_sig, new_mcp_scores = self._calculate_cca_sig(
|
@@ -805,7 +798,7 @@ class Dialogue:
|
|
805
798
|
for mcp in mcps:
|
806
799
|
mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
|
807
800
|
|
808
|
-
# TODO Check
|
801
|
+
# TODO Check whether the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
|
809
802
|
result = {}
|
810
803
|
result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
|
811
804
|
mcp_name=mcp,
|
@@ -875,22 +868,19 @@ class Dialogue:
|
|
875
868
|
sample_label = self.sample_id
|
876
869
|
n_mcps = self.n_mcps
|
877
870
|
|
878
|
-
# create conditions_compare if not supplied
|
879
871
|
if conditions_compare is None:
|
880
|
-
conditions_compare = list(adata.obs[
|
872
|
+
conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
|
881
873
|
if len(conditions_compare) != 2:
|
882
874
|
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
|
883
875
|
|
884
|
-
# create data frames to store results
|
885
876
|
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
886
877
|
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
887
878
|
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
888
879
|
|
889
880
|
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
|
890
881
|
for celltype in adata.obs[celltype_label].unique():
|
891
|
-
# subset data to cell type
|
892
882
|
df = adata.obs[adata.obs[celltype_label] == celltype]
|
893
|
-
|
883
|
+
|
894
884
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
895
885
|
mns = df.groupby(sample_label)[mcpnum].mean()
|
896
886
|
mns = pd.concat([mns, response], axis=1)
|
@@ -900,11 +890,10 @@ class Dialogue:
|
|
900
890
|
)
|
901
891
|
pvals.loc[celltype, mcpnum] = res[1]
|
902
892
|
tstats.loc[celltype, mcpnum] = res[0]
|
903
|
-
# return(res)
|
904
893
|
|
905
|
-
# benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
|
906
894
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
907
895
|
pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
|
896
|
+
|
908
897
|
return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
|
909
898
|
|
910
899
|
def get_mlm_mcp_genes(
|
@@ -921,10 +910,8 @@ class Dialogue:
|
|
921
910
|
celltype: Cell type of interest.
|
922
911
|
results: dl.MultilevelModeling result object.
|
923
912
|
MCP: MCP key of the result object.
|
924
|
-
|
925
|
-
Defaults to 0.70.
|
913
|
+
threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
|
926
914
|
focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
|
927
|
-
Defaults to None.
|
928
915
|
|
929
916
|
Returns:
|
930
917
|
Dict with keys 'up_genes' and 'down_genes' and values of lists of genes
|
@@ -945,7 +932,6 @@ class Dialogue:
|
|
945
932
|
# REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
|
946
933
|
if MCP.startswith("mcp_"):
|
947
934
|
MCP = MCP.replace("mcp_", "MCP")
|
948
|
-
# convert from MCPx to MCPx+1
|
949
935
|
MCP = "MCP" + str(int(MCP[3:]) - 1)
|
950
936
|
|
951
937
|
# Extract all comparison keys from the results object
|
@@ -1004,27 +990,24 @@ class Dialogue:
|
|
1004
990
|
Args:
|
1005
991
|
ct_subs: Dialogue output ct_subs dictionary
|
1006
992
|
mcp: The name of the marker gene expression column.
|
1007
|
-
Defaults to "mcp_0".
|
1008
993
|
fraction: Fraction of extreme cells to consider for gene ranking.
|
1009
994
|
Should be between 0 and 1.
|
1010
|
-
Defaults to 0.1.
|
1011
995
|
|
1012
996
|
Returns:
|
1013
997
|
Dictionary where keys are subpopulation names and values are Anndata
|
1014
998
|
objects containing the results of gene ranking analysis.
|
1015
999
|
|
1016
1000
|
Examples:
|
1017
|
-
ct_subs = {
|
1018
|
-
"subpop1": anndata_obj1,
|
1019
|
-
"subpop2": anndata_obj2,
|
1020
|
-
# ... more subpopulations ...
|
1021
|
-
}
|
1022
|
-
genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1001
|
+
>>> ct_subs = {
|
1002
|
+
... "subpop1": anndata_obj1,
|
1003
|
+
... "subpop2": anndata_obj2,
|
1004
|
+
... # ... more subpopulations ...
|
1005
|
+
... }
|
1006
|
+
>>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1023
1007
|
"""
|
1024
1008
|
genes = {}
|
1025
1009
|
for ct in ct_subs.keys():
|
1026
1010
|
mini = ct_subs[ct]
|
1027
|
-
mini.obs[mcp]
|
1028
1011
|
mini.obs["extrema"] = pd.qcut(
|
1029
1012
|
mini.obs[mcp],
|
1030
1013
|
[0, 0 + fraction, 1 - fraction, 1.0],
|
@@ -1034,6 +1017,7 @@ class Dialogue:
|
|
1034
1017
|
mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
|
1035
1018
|
)
|
1036
1019
|
genes[ct] = mini # .uns['rank_genes_groups']
|
1020
|
+
|
1037
1021
|
return genes
|
1038
1022
|
|
1039
1023
|
def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
|
@@ -1046,7 +1030,7 @@ class Dialogue:
|
|
1046
1030
|
Args:
|
1047
1031
|
ct_subs: Dialogue output ct_subs dictionary
|
1048
1032
|
fraction: Fraction of extreme cells to consider for gene ranking.
|
1049
|
-
Should be between 0 and 1.
|
1033
|
+
Should be between 0 and 1.
|
1050
1034
|
|
1051
1035
|
Returns:
|
1052
1036
|
Nested dictionary where keys of the first level are MCPs (of the form "mcp_0" etc)
|
@@ -1064,7 +1048,7 @@ class Dialogue:
|
|
1064
1048
|
>>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
|
1065
1049
|
"""
|
1066
1050
|
rank_dfs: dict[str, dict[Any, Any]] = {}
|
1067
|
-
|
1051
|
+
ct_sub = next(iter(ct_subs.values()))
|
1068
1052
|
mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
|
1069
1053
|
|
1070
1054
|
for mcp in mcps:
|
@@ -1072,4 +1056,123 @@ class Dialogue:
|
|
1072
1056
|
ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
|
1073
1057
|
for celltype in ct_ranked.keys():
|
1074
1058
|
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
|
1059
|
+
|
1075
1060
|
return rank_dfs
|
1061
|
+
|
1062
|
+
def plot_split_violins(
|
1063
|
+
self,
|
1064
|
+
adata: AnnData,
|
1065
|
+
split_key: str,
|
1066
|
+
celltype_key: str,
|
1067
|
+
split_which: tuple[str, str] = None,
|
1068
|
+
mcp: str = "mcp_0",
|
1069
|
+
return_fig: bool | None = None,
|
1070
|
+
ax: Axes | None = None,
|
1071
|
+
save: bool | str | None = None,
|
1072
|
+
show: bool | None = None,
|
1073
|
+
) -> Axes | Figure | None:
|
1074
|
+
"""Plots split violin plots for a given MCP and split variable.
|
1075
|
+
|
1076
|
+
Any cells with a value for split_key not in split_which are removed from the plot.
|
1077
|
+
|
1078
|
+
Args:
|
1079
|
+
adata: Annotated data object.
|
1080
|
+
split_key: Variable in adata.obs used to split the data.
|
1081
|
+
celltype_key: Key for cell type annotations.
|
1082
|
+
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
1083
|
+
mcp: Key for MCP data.
|
1084
|
+
|
1085
|
+
Returns:
|
1086
|
+
A :class:`~matplotlib.axes.Axes` object
|
1087
|
+
|
1088
|
+
Examples:
|
1089
|
+
>>> import pertpy as pt
|
1090
|
+
>>> import scanpy as sc
|
1091
|
+
>>> adata = pt.dt.dialogue_example()
|
1092
|
+
>>> sc.pp.pca(adata)
|
1093
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1094
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1095
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1096
|
+
>>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
|
1097
|
+
|
1098
|
+
Preview:
|
1099
|
+
.. image:: /_static/docstring_previews/dialogue_violin.png
|
1100
|
+
"""
|
1101
|
+
df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
|
1102
|
+
if split_which is None:
|
1103
|
+
split_which = df[split_key].unique()
|
1104
|
+
df = df[df[split_key].isin(split_which)]
|
1105
|
+
df[split_key] = df[split_key].cat.remove_unused_categories()
|
1106
|
+
|
1107
|
+
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1108
|
+
|
1109
|
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1110
|
+
|
1111
|
+
if save:
|
1112
|
+
plt.savefig(save, bbox_inches="tight")
|
1113
|
+
if show:
|
1114
|
+
plt.show()
|
1115
|
+
if return_fig:
|
1116
|
+
return plt.gcf()
|
1117
|
+
if not (show or save):
|
1118
|
+
return ax
|
1119
|
+
return None
|
1120
|
+
|
1121
|
+
def plot_pairplot(
|
1122
|
+
self,
|
1123
|
+
adata: AnnData,
|
1124
|
+
celltype_key: str,
|
1125
|
+
color: str,
|
1126
|
+
sample_id: str,
|
1127
|
+
mcp: str = "mcp_0",
|
1128
|
+
return_fig: bool | None = None,
|
1129
|
+
show: bool | None = None,
|
1130
|
+
save: bool | str | None = None,
|
1131
|
+
) -> PairGrid | Figure | None:
|
1132
|
+
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
1133
|
+
|
1134
|
+
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
1135
|
+
then creates a pairplot to visualize the relationships between these mean MCP values.
|
1136
|
+
|
1137
|
+
Args:
|
1138
|
+
adata: Annotated data object.
|
1139
|
+
celltype_key: Key in `adata.obs` containing cell type annotations.
|
1140
|
+
color: Key in `adata.obs` for color annotations. This parameter is used as the hue
|
1141
|
+
sample_id: Key in `adata.obs` for the sample annotations.
|
1142
|
+
mcp: Key in `adata.obs` for MCP feature values.
|
1143
|
+
|
1144
|
+
Returns:
|
1145
|
+
Seaborn Pairgrid object.
|
1146
|
+
|
1147
|
+
Examples:
|
1148
|
+
>>> import pertpy as pt
|
1149
|
+
>>> import scanpy as sc
|
1150
|
+
>>> adata = pt.dt.dialogue_example()
|
1151
|
+
>>> sc.pp.pca(adata)
|
1152
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1153
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1154
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1155
|
+
>>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
|
1156
|
+
|
1157
|
+
Preview:
|
1158
|
+
.. image:: /_static/docstring_previews/dialogue_pairplot.png
|
1159
|
+
"""
|
1160
|
+
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
|
1161
|
+
mean_mcps = mean_mcps.reset_index()
|
1162
|
+
mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
|
1163
|
+
|
1164
|
+
aggstats = adata.obs.groupby([sample_id])[color].describe()
|
1165
|
+
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
1166
|
+
aggstats[color] = aggstats["top"]
|
1167
|
+
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1168
|
+
ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1169
|
+
|
1170
|
+
if save:
|
1171
|
+
plt.savefig(save, bbox_inches="tight")
|
1172
|
+
if show:
|
1173
|
+
plt.show()
|
1174
|
+
if return_fig:
|
1175
|
+
return plt.gcf()
|
1176
|
+
if not (show or save):
|
1177
|
+
return ax
|
1178
|
+
return None
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from ._base import ContrastType, LinearModelBase, MethodBase
|
2
|
+
from ._dge_comparison import DGEEVAL
|
3
|
+
from ._edger import EdgeR
|
4
|
+
from ._pydeseq2 import PyDESeq2
|
5
|
+
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
|
6
|
+
from ._statsmodels import Statsmodels
|
7
|
+
|
8
|
+
__all__ = [
|
9
|
+
"MethodBase",
|
10
|
+
"LinearModelBase",
|
11
|
+
"EdgeR",
|
12
|
+
"PyDESeq2",
|
13
|
+
"Statsmodels",
|
14
|
+
"SimpleComparisonBase",
|
15
|
+
"WilcoxonTest",
|
16
|
+
"TTest",
|
17
|
+
"ContrastType",
|
18
|
+
]
|
19
|
+
|
20
|
+
AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
|