pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -2,27 +2,33 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import itertools
|
4
4
|
from collections import defaultdict
|
5
|
-
from typing import Any, Literal
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal
|
6
6
|
|
7
7
|
import anndata as ad
|
8
|
+
import matplotlib.pyplot as plt
|
8
9
|
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
import scanpy as sc
|
11
|
-
import
|
12
|
+
import seaborn as sns
|
12
13
|
import statsmodels.formula.api as smf
|
13
14
|
import statsmodels.stats.multitest as ssm
|
14
15
|
from anndata import AnnData
|
16
|
+
from lamin_utils import logger
|
15
17
|
from pandas import DataFrame
|
16
|
-
from rich import print
|
17
18
|
from rich.console import Group
|
18
19
|
from rich.live import Live
|
19
20
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
20
21
|
from scipy import stats
|
21
22
|
from scipy.optimize import nnls
|
23
|
+
from seaborn import PairGrid
|
22
24
|
from sklearn.linear_model import LinearRegression
|
23
25
|
from sparsecca import lp_pmd, multicca_permute, multicca_pmd
|
24
26
|
from statsmodels.sandbox.stats.multicomp import multipletests
|
25
27
|
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from matplotlib.axes import Axes
|
30
|
+
from matplotlib.figure import Figure
|
31
|
+
|
26
32
|
|
27
33
|
class Dialogue:
|
28
34
|
"""Python implementation of DIALOGUE"""
|
@@ -53,8 +59,6 @@ class Dialogue:
|
|
53
59
|
|
54
60
|
Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
|
55
61
|
|
56
|
-
# TODO: Replace with decoupler's implementation
|
57
|
-
|
58
62
|
Args:
|
59
63
|
groupby: The key to groupby for pseudobulks
|
60
64
|
strategy: The pseudobulking strategy. One of "median" or "mean"
|
@@ -62,14 +66,15 @@ class Dialogue:
|
|
62
66
|
Returns:
|
63
67
|
A Pandas DataFrame of pseudobulk counts
|
64
68
|
"""
|
69
|
+
# TODO: Replace with decoupler's implementation
|
65
70
|
pseudobulk = {"Genes": adata.var_names.values}
|
66
71
|
|
67
72
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
68
73
|
temp = adata.obs.loc[:, groupby] == category
|
69
74
|
if strategy == "median":
|
70
|
-
pseudobulk[category] = adata[temp].X.median(axis=0)
|
75
|
+
pseudobulk[category] = adata[temp].X.median(axis=0)
|
71
76
|
elif strategy == "mean":
|
72
|
-
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
77
|
+
pseudobulk[category] = adata[temp].X.mean(axis=0)
|
73
78
|
|
74
79
|
pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
|
75
80
|
|
@@ -101,8 +106,6 @@ class Dialogue:
|
|
101
106
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
102
107
|
"""Row-wise mean center and scale by the standard deviation.
|
103
108
|
|
104
|
-
TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
105
|
-
|
106
109
|
Args:
|
107
110
|
pseudobulks: The pseudobulk PCA components.
|
108
111
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
@@ -110,9 +113,9 @@ class Dialogue:
|
|
110
113
|
Returns:
|
111
114
|
The scaled count matrix.
|
112
115
|
"""
|
116
|
+
# TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
|
113
117
|
# DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
|
114
118
|
# However, performing this scaling _does_ increase overall correlation of the end result
|
115
|
-
# WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
|
116
119
|
if normalize:
|
117
120
|
return pseudobulks.to_numpy()
|
118
121
|
else:
|
@@ -288,7 +291,7 @@ class Dialogue:
|
|
288
291
|
mcp_name: Name of mcp which was used for calculation of column value.
|
289
292
|
max_length: Value needed to later decide at what index the threshold value should be extracted from column.
|
290
293
|
min_threshold: Minimal threshold to select final scores by if it is smaller than calculated threshold.
|
291
|
-
index: Column index to use eto calculate the significant genes.
|
294
|
+
index: Column index to use eto calculate the significant genes.
|
292
295
|
|
293
296
|
Returns:
|
294
297
|
According to the values in a df column (default: zscore) the significant up and downregulated gene names
|
@@ -313,13 +316,13 @@ class Dialogue:
|
|
313
316
|
def _apply_HLM_per_MCP_for_one_pair(
|
314
317
|
self,
|
315
318
|
mcp_name: str,
|
316
|
-
scores_df:
|
319
|
+
scores_df: pd.DataFrame,
|
317
320
|
ct_data: AnnData,
|
318
321
|
tme: pd.DataFrame,
|
319
322
|
sig: dict,
|
320
323
|
n_counts: str,
|
321
324
|
formula: str,
|
322
|
-
confounder: str,
|
325
|
+
confounder: str | None,
|
323
326
|
) -> tuple[pd.DataFrame, dict[str, Any]]:
|
324
327
|
"""Applies hierarchical modeling for a single MCP.
|
325
328
|
|
@@ -340,7 +343,7 @@ class Dialogue:
|
|
340
343
|
"""
|
341
344
|
HLM_result = self._mixed_effects(
|
342
345
|
scores=scores_df[[mcp_name]],
|
343
|
-
x_labels=ct_data.obs[[n_counts, confounder]],
|
346
|
+
x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
|
344
347
|
tme=tme,
|
345
348
|
genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
|
346
349
|
formula=formula,
|
@@ -367,19 +370,13 @@ class Dialogue:
|
|
367
370
|
return np.array(resid)
|
368
371
|
|
369
372
|
def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
|
370
|
-
"""Solves non-negative least
|
373
|
+
"""Solves non-negative least-squares separately for different feature categories.
|
371
374
|
|
372
375
|
Mimics DLG.iterative.nnls.
|
373
376
|
Variables are notated according to:
|
374
377
|
|
375
378
|
`argmin|Ax - y|`
|
376
379
|
|
377
|
-
Args:
|
378
|
-
A_orig:
|
379
|
-
y_orig:
|
380
|
-
feature_ranks:
|
381
|
-
n_iter: Passed to scipy.optimize.nnls. Defaults to 1000.
|
382
|
-
|
383
380
|
Returns:
|
384
381
|
Returns the aggregated coefficients from nnls.
|
385
382
|
"""
|
@@ -398,7 +395,7 @@ class Dialogue:
|
|
398
395
|
|
399
396
|
x_final = np.zeros(A_orig.shape[0])
|
400
397
|
Ax = np.zeros(A_orig.shape[1])
|
401
|
-
for _, mask in zip(sig_ranks, masks):
|
398
|
+
for _, mask in zip(sig_ranks, masks, strict=False):
|
402
399
|
A = A_orig[mask].T
|
403
400
|
coef_nnls, _ = nnls(A, y, maxiter=n_iter)
|
404
401
|
y = y - A @ coef_nnls # residuals
|
@@ -516,8 +513,8 @@ class Dialogue:
|
|
516
513
|
# TODO: probably format the up and down within get_top_elements
|
517
514
|
cca_sig: dict[str, Any] = defaultdict(dict)
|
518
515
|
for i in range(0, int(len(cca_sig_unformatted) / 2)):
|
519
|
-
cca_sig[f"MCP{i
|
520
|
-
cca_sig[f"MCP{i
|
516
|
+
cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
|
517
|
+
cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
|
521
518
|
|
522
519
|
cca_sig = dict(cca_sig)
|
523
520
|
cca_sig_results[ct] = cca_sig
|
@@ -555,7 +552,7 @@ class Dialogue:
|
|
555
552
|
|
556
553
|
return cca_sig_results, new_mcp_scores
|
557
554
|
|
558
|
-
def
|
555
|
+
def _load(
|
559
556
|
self,
|
560
557
|
adata: AnnData,
|
561
558
|
ct_order: list[str],
|
@@ -569,21 +566,11 @@ class Dialogue:
|
|
569
566
|
Args:
|
570
567
|
adata: AnnData object generate celltype objects for
|
571
568
|
ct_order: The order of cell types
|
572
|
-
agg_pca: Whether to aggregate pseudobulks with PCA or not.
|
573
|
-
normalize: Whether to mimic DIALOGUE behavior or not.
|
569
|
+
agg_pca: Whether to aggregate pseudobulks with PCA or not.
|
570
|
+
normalize: Whether to mimic DIALOGUE behavior or not.
|
574
571
|
|
575
572
|
Returns:
|
576
573
|
A celltype_label:array dictionary.
|
577
|
-
|
578
|
-
Examples:
|
579
|
-
>>> import pertpy as pt
|
580
|
-
>>> import scanpy as sc
|
581
|
-
>>> adata = pt.dt.dialogue_example()
|
582
|
-
>>> sc.pp.pca(adata)
|
583
|
-
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
584
|
-
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
585
|
-
>>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
|
586
|
-
>>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
|
587
574
|
"""
|
588
575
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
589
576
|
fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
|
@@ -620,7 +607,6 @@ class Dialogue:
|
|
620
607
|
agg_pca: Whether to calculate cell-averaged PCA components.
|
621
608
|
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
|
622
609
|
For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
|
623
|
-
Defaults to 'bs'.
|
624
610
|
normalize: Whether to mimic DIALOGUE as close as possible
|
625
611
|
|
626
612
|
Returns:
|
@@ -631,25 +617,31 @@ class Dialogue:
|
|
631
617
|
>>> import scanpy as sc
|
632
618
|
>>> adata = pt.dt.dialogue_example()
|
633
619
|
>>> sc.pp.pca(adata)
|
634
|
-
>>> dl = pt.tl.Dialogue(
|
635
|
-
|
620
|
+
>>> dl = pt.tl.Dialogue(
|
621
|
+
... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
|
622
|
+
... )
|
636
623
|
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
637
624
|
"""
|
638
|
-
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
639
|
-
# it is important here that to obtain the same result as in R, we pass the matrices in
|
640
|
-
# in the same order.
|
625
|
+
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
|
626
|
+
# As such, it is important here that to obtain the same result as in R, we pass the matrices in the same order.
|
641
627
|
if ct_order is not None:
|
642
628
|
cell_types = ct_order
|
643
629
|
else:
|
644
630
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
645
631
|
|
646
|
-
mcca_in, ct_subs = self.
|
632
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
|
647
633
|
|
648
634
|
n_samples = mcca_in[0].shape[1]
|
649
635
|
if penalties is None:
|
650
|
-
|
651
|
-
|
652
|
-
|
636
|
+
try:
|
637
|
+
penalties = multicca_permute(
|
638
|
+
mcca_in, penalties=np.sqrt(n_samples) / 2, nperms=10, niter=50, standardize=True
|
639
|
+
)["bestpenalties"]
|
640
|
+
except ValueError as e:
|
641
|
+
if "matmul: input operand 1 has a mismatch in its core dimension" in str(e):
|
642
|
+
raise ValueError("Please ensure that every cell type is represented in every sample.") from e
|
643
|
+
else:
|
644
|
+
raise
|
653
645
|
else:
|
654
646
|
penalties = penalties
|
655
647
|
|
@@ -685,7 +677,7 @@ class Dialogue:
|
|
685
677
|
ct_subs: dict,
|
686
678
|
mcp_scores: dict,
|
687
679
|
ws_dict: dict,
|
688
|
-
confounder: str,
|
680
|
+
confounder: str | None,
|
689
681
|
formula: str = None,
|
690
682
|
):
|
691
683
|
"""Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
|
@@ -700,7 +692,6 @@ class Dialogue:
|
|
700
692
|
A Pandas DataFrame containing:
|
701
693
|
- for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
|
702
694
|
- merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
|
703
|
-
TODO: Describe both returns
|
704
695
|
|
705
696
|
Examples:
|
706
697
|
>>> import pertpy as pt
|
@@ -713,7 +704,9 @@ class Dialogue:
|
|
713
704
|
>>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
|
714
705
|
confounder="gender")
|
715
706
|
"""
|
716
|
-
#
|
707
|
+
# TODO the returns of the function better
|
708
|
+
|
709
|
+
# all possible pairs of cell types without pairing same cell type
|
717
710
|
cell_types = list(ct_subs.keys())
|
718
711
|
pairs = list(itertools.combinations(cell_types, 2))
|
719
712
|
|
@@ -721,9 +714,9 @@ class Dialogue:
|
|
721
714
|
formula = f"y ~ x + {self.n_counts_key}"
|
722
715
|
|
723
716
|
# Hierarchical modeling expects DataFrames
|
724
|
-
mcp_cell_types = {f"MCP{i
|
717
|
+
mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
|
725
718
|
mcp_scores_df = {
|
726
|
-
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
|
719
|
+
ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
|
727
720
|
for ct, v in mcp_scores.items()
|
728
721
|
}
|
729
722
|
|
@@ -762,10 +755,10 @@ class Dialogue:
|
|
762
755
|
mcps.append(mcp)
|
763
756
|
|
764
757
|
if len(mcps) == 0:
|
765
|
-
|
758
|
+
logger.warning(f"No shared MCPs between {cell_type_1} and {cell_type_2}.")
|
766
759
|
continue
|
767
760
|
|
768
|
-
|
761
|
+
logger.info(f"{len(mcps)} MCPs identified for {cell_type_1} and {cell_type_2}.")
|
769
762
|
|
770
763
|
new_mcp_scores: dict[Any, list[Any]]
|
771
764
|
cca_sig, new_mcp_scores = self._calculate_cca_sig(
|
@@ -805,7 +798,7 @@ class Dialogue:
|
|
805
798
|
for mcp in mcps:
|
806
799
|
mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
|
807
800
|
|
808
|
-
# TODO Check
|
801
|
+
# TODO Check whether the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
|
809
802
|
result = {}
|
810
803
|
result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
|
811
804
|
mcp_name=mcp,
|
@@ -875,22 +868,19 @@ class Dialogue:
|
|
875
868
|
sample_label = self.sample_id
|
876
869
|
n_mcps = self.n_mcps
|
877
870
|
|
878
|
-
# create conditions_compare if not supplied
|
879
871
|
if conditions_compare is None:
|
880
|
-
conditions_compare = list(adata.obs[
|
872
|
+
conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
|
881
873
|
if len(conditions_compare) != 2:
|
882
874
|
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
|
883
875
|
|
884
|
-
# create data frames to store results
|
885
876
|
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
886
877
|
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
887
878
|
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
|
888
879
|
|
889
880
|
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
|
890
881
|
for celltype in adata.obs[celltype_label].unique():
|
891
|
-
# subset data to cell type
|
892
882
|
df = adata.obs[adata.obs[celltype_label] == celltype]
|
893
|
-
|
883
|
+
|
894
884
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
895
885
|
mns = df.groupby(sample_label)[mcpnum].mean()
|
896
886
|
mns = pd.concat([mns, response], axis=1)
|
@@ -900,11 +890,10 @@ class Dialogue:
|
|
900
890
|
)
|
901
891
|
pvals.loc[celltype, mcpnum] = res[1]
|
902
892
|
tstats.loc[celltype, mcpnum] = res[0]
|
903
|
-
# return(res)
|
904
893
|
|
905
|
-
# benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
|
906
894
|
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
|
907
895
|
pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
|
896
|
+
|
908
897
|
return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
|
909
898
|
|
910
899
|
def get_mlm_mcp_genes(
|
@@ -921,10 +910,8 @@ class Dialogue:
|
|
921
910
|
celltype: Cell type of interest.
|
922
911
|
results: dl.MultilevelModeling result object.
|
923
912
|
MCP: MCP key of the result object.
|
924
|
-
|
925
|
-
Defaults to 0.70.
|
913
|
+
threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
|
926
914
|
focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
|
927
|
-
Defaults to None.
|
928
915
|
|
929
916
|
Returns:
|
930
917
|
Dict with keys 'up_genes' and 'down_genes' and values of lists of genes
|
@@ -945,7 +932,6 @@ class Dialogue:
|
|
945
932
|
# REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
|
946
933
|
if MCP.startswith("mcp_"):
|
947
934
|
MCP = MCP.replace("mcp_", "MCP")
|
948
|
-
# convert from MCPx to MCPx+1
|
949
935
|
MCP = "MCP" + str(int(MCP[3:]) - 1)
|
950
936
|
|
951
937
|
# Extract all comparison keys from the results object
|
@@ -1004,27 +990,24 @@ class Dialogue:
|
|
1004
990
|
Args:
|
1005
991
|
ct_subs: Dialogue output ct_subs dictionary
|
1006
992
|
mcp: The name of the marker gene expression column.
|
1007
|
-
Defaults to "mcp_0".
|
1008
993
|
fraction: Fraction of extreme cells to consider for gene ranking.
|
1009
994
|
Should be between 0 and 1.
|
1010
|
-
Defaults to 0.1.
|
1011
995
|
|
1012
996
|
Returns:
|
1013
997
|
Dictionary where keys are subpopulation names and values are Anndata
|
1014
998
|
objects containing the results of gene ranking analysis.
|
1015
999
|
|
1016
1000
|
Examples:
|
1017
|
-
ct_subs = {
|
1018
|
-
"subpop1": anndata_obj1,
|
1019
|
-
"subpop2": anndata_obj2,
|
1020
|
-
# ... more subpopulations ...
|
1021
|
-
}
|
1022
|
-
genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1001
|
+
>>> ct_subs = {
|
1002
|
+
... "subpop1": anndata_obj1,
|
1003
|
+
... "subpop2": anndata_obj2,
|
1004
|
+
... # ... more subpopulations ...
|
1005
|
+
... }
|
1006
|
+
>>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1023
1007
|
"""
|
1024
1008
|
genes = {}
|
1025
1009
|
for ct in ct_subs.keys():
|
1026
1010
|
mini = ct_subs[ct]
|
1027
|
-
mini.obs[mcp]
|
1028
1011
|
mini.obs["extrema"] = pd.qcut(
|
1029
1012
|
mini.obs[mcp],
|
1030
1013
|
[0, 0 + fraction, 1 - fraction, 1.0],
|
@@ -1034,6 +1017,7 @@ class Dialogue:
|
|
1034
1017
|
mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
|
1035
1018
|
)
|
1036
1019
|
genes[ct] = mini # .uns['rank_genes_groups']
|
1020
|
+
|
1037
1021
|
return genes
|
1038
1022
|
|
1039
1023
|
def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
|
@@ -1046,7 +1030,7 @@ class Dialogue:
|
|
1046
1030
|
Args:
|
1047
1031
|
ct_subs: Dialogue output ct_subs dictionary
|
1048
1032
|
fraction: Fraction of extreme cells to consider for gene ranking.
|
1049
|
-
Should be between 0 and 1.
|
1033
|
+
Should be between 0 and 1.
|
1050
1034
|
|
1051
1035
|
Returns:
|
1052
1036
|
Nested dictionary where keys of the first level are MCPs (of the form "mcp_0" etc)
|
@@ -1064,7 +1048,7 @@ class Dialogue:
|
|
1064
1048
|
>>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
|
1065
1049
|
"""
|
1066
1050
|
rank_dfs: dict[str, dict[Any, Any]] = {}
|
1067
|
-
|
1051
|
+
ct_sub = next(iter(ct_subs.values()))
|
1068
1052
|
mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
|
1069
1053
|
|
1070
1054
|
for mcp in mcps:
|
@@ -1072,4 +1056,123 @@ class Dialogue:
|
|
1072
1056
|
ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
|
1073
1057
|
for celltype in ct_ranked.keys():
|
1074
1058
|
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
|
1059
|
+
|
1075
1060
|
return rank_dfs
|
1061
|
+
|
1062
|
+
def plot_split_violins(
|
1063
|
+
self,
|
1064
|
+
adata: AnnData,
|
1065
|
+
split_key: str,
|
1066
|
+
celltype_key: str,
|
1067
|
+
split_which: tuple[str, str] = None,
|
1068
|
+
mcp: str = "mcp_0",
|
1069
|
+
return_fig: bool | None = None,
|
1070
|
+
ax: Axes | None = None,
|
1071
|
+
save: bool | str | None = None,
|
1072
|
+
show: bool | None = None,
|
1073
|
+
) -> Axes | Figure | None:
|
1074
|
+
"""Plots split violin plots for a given MCP and split variable.
|
1075
|
+
|
1076
|
+
Any cells with a value for split_key not in split_which are removed from the plot.
|
1077
|
+
|
1078
|
+
Args:
|
1079
|
+
adata: Annotated data object.
|
1080
|
+
split_key: Variable in adata.obs used to split the data.
|
1081
|
+
celltype_key: Key for cell type annotations.
|
1082
|
+
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
1083
|
+
mcp: Key for MCP data.
|
1084
|
+
|
1085
|
+
Returns:
|
1086
|
+
A :class:`~matplotlib.axes.Axes` object
|
1087
|
+
|
1088
|
+
Examples:
|
1089
|
+
>>> import pertpy as pt
|
1090
|
+
>>> import scanpy as sc
|
1091
|
+
>>> adata = pt.dt.dialogue_example()
|
1092
|
+
>>> sc.pp.pca(adata)
|
1093
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1094
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1095
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1096
|
+
>>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
|
1097
|
+
|
1098
|
+
Preview:
|
1099
|
+
.. image:: /_static/docstring_previews/dialogue_violin.png
|
1100
|
+
"""
|
1101
|
+
df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
|
1102
|
+
if split_which is None:
|
1103
|
+
split_which = df[split_key].unique()
|
1104
|
+
df = df[df[split_key].isin(split_which)]
|
1105
|
+
df[split_key] = df[split_key].cat.remove_unused_categories()
|
1106
|
+
|
1107
|
+
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1108
|
+
|
1109
|
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1110
|
+
|
1111
|
+
if save:
|
1112
|
+
plt.savefig(save, bbox_inches="tight")
|
1113
|
+
if show:
|
1114
|
+
plt.show()
|
1115
|
+
if return_fig:
|
1116
|
+
return plt.gcf()
|
1117
|
+
if not (show or save):
|
1118
|
+
return ax
|
1119
|
+
return None
|
1120
|
+
|
1121
|
+
def plot_pairplot(
|
1122
|
+
self,
|
1123
|
+
adata: AnnData,
|
1124
|
+
celltype_key: str,
|
1125
|
+
color: str,
|
1126
|
+
sample_id: str,
|
1127
|
+
mcp: str = "mcp_0",
|
1128
|
+
return_fig: bool | None = None,
|
1129
|
+
show: bool | None = None,
|
1130
|
+
save: bool | str | None = None,
|
1131
|
+
) -> PairGrid | Figure | None:
|
1132
|
+
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
1133
|
+
|
1134
|
+
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
1135
|
+
then creates a pairplot to visualize the relationships between these mean MCP values.
|
1136
|
+
|
1137
|
+
Args:
|
1138
|
+
adata: Annotated data object.
|
1139
|
+
celltype_key: Key in `adata.obs` containing cell type annotations.
|
1140
|
+
color: Key in `adata.obs` for color annotations. This parameter is used as the hue
|
1141
|
+
sample_id: Key in `adata.obs` for the sample annotations.
|
1142
|
+
mcp: Key in `adata.obs` for MCP feature values.
|
1143
|
+
|
1144
|
+
Returns:
|
1145
|
+
Seaborn Pairgrid object.
|
1146
|
+
|
1147
|
+
Examples:
|
1148
|
+
>>> import pertpy as pt
|
1149
|
+
>>> import scanpy as sc
|
1150
|
+
>>> adata = pt.dt.dialogue_example()
|
1151
|
+
>>> sc.pp.pca(adata)
|
1152
|
+
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
1153
|
+
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
1154
|
+
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
1155
|
+
>>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
|
1156
|
+
|
1157
|
+
Preview:
|
1158
|
+
.. image:: /_static/docstring_previews/dialogue_pairplot.png
|
1159
|
+
"""
|
1160
|
+
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
|
1161
|
+
mean_mcps = mean_mcps.reset_index()
|
1162
|
+
mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
|
1163
|
+
|
1164
|
+
aggstats = adata.obs.groupby([sample_id])[color].describe()
|
1165
|
+
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
1166
|
+
aggstats[color] = aggstats["top"]
|
1167
|
+
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1168
|
+
ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1169
|
+
|
1170
|
+
if save:
|
1171
|
+
plt.savefig(save, bbox_inches="tight")
|
1172
|
+
if show:
|
1173
|
+
plt.show()
|
1174
|
+
if return_fig:
|
1175
|
+
return plt.gcf()
|
1176
|
+
if not (show or save):
|
1177
|
+
return ax
|
1178
|
+
return None
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from ._base import ContrastType, LinearModelBase, MethodBase
|
2
|
+
from ._dge_comparison import DGEEVAL
|
3
|
+
from ._edger import EdgeR
|
4
|
+
from ._pydeseq2 import PyDESeq2
|
5
|
+
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
|
6
|
+
from ._statsmodels import Statsmodels
|
7
|
+
|
8
|
+
__all__ = [
|
9
|
+
"MethodBase",
|
10
|
+
"LinearModelBase",
|
11
|
+
"EdgeR",
|
12
|
+
"PyDESeq2",
|
13
|
+
"Statsmodels",
|
14
|
+
"SimpleComparisonBase",
|
15
|
+
"WilcoxonTest",
|
16
|
+
"TTest",
|
17
|
+
"ContrastType",
|
18
|
+
]
|
19
|
+
|
20
|
+
AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
|