pertpy-0.9.5-py3-none-any.whl → pertpy-0.11.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +2 -5
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +136 -30
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +221 -39
- pertpy/preprocessing/_guide_rna_mixture.py +177 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +138 -142
- pertpy/tools/_cinemaot.py +75 -117
- pertpy/tools/_coda/_base_coda.py +150 -174
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +60 -56
- pertpy/tools/_differential_gene_expression/_base.py +25 -43
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +86 -92
- pertpy/tools/_enrichment.py +8 -25
- pertpy/tools/_milo.py +23 -27
- pertpy/tools/_mixscape.py +261 -175
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +13 -17
- pertpy/tools/_scgen/_scgen.py +17 -20
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.9.5.dist-info/RECORD +0 -57
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -33,9 +33,17 @@ if TYPE_CHECKING:
|
|
33
33
|
|
34
34
|
|
35
35
|
class Dialogue:
|
36
|
-
"""Python implementation of DIALOGUE"""
|
36
|
+
"""Python implementation of DIALOGUE."""
|
37
37
|
|
38
|
-
def __init__(
|
38
|
+
def __init__(
|
39
|
+
self,
|
40
|
+
sample_id: str,
|
41
|
+
celltype_key: str,
|
42
|
+
n_counts_key: str,
|
43
|
+
n_mpcs: int,
|
44
|
+
feature_space_key: str = "X_pca",
|
45
|
+
n_components: int = 50,
|
46
|
+
):
|
39
47
|
"""Constructor for Dialogue.
|
40
48
|
|
41
49
|
Args:
|
@@ -43,6 +51,8 @@ class Dialogue:
|
|
43
51
|
celltype_key: The key in AnnData.obs which contains the cell type column.
|
44
52
|
n_counts_key: The key of the number of counts in Anndata.obs . Also commonly the size factor.
|
45
53
|
n_mpcs: Number of PMD components which corresponds to the number of determined MCPs.
|
54
|
+
feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
|
55
|
+
n_components: The number of components of the feature space to use, e.g. PCA components.
|
46
56
|
"""
|
47
57
|
self.sample_id = sample_id
|
48
58
|
self.celltype_key = celltype_key
|
@@ -53,6 +63,8 @@ class Dialogue:
|
|
53
63
|
)
|
54
64
|
self.n_counts_key = n_counts_key
|
55
65
|
self.n_mcps = n_mpcs
|
66
|
+
self.feature_space_key = feature_space_key
|
67
|
+
self.n_components = n_components
|
56
68
|
|
57
69
|
def _get_pseudobulks(
|
58
70
|
self, adata: AnnData, groupby: str, strategy: Literal["median", "mean"] = "median"
|
@@ -62,6 +74,7 @@ class Dialogue:
|
|
62
74
|
Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
|
63
75
|
|
64
76
|
Args:
|
77
|
+
adata: Annotated data matrix.
|
65
78
|
groupby: The key to groupby for pseudobulks
|
66
79
|
strategy: The pseudobulking strategy. One of "median" or "mean"
|
67
80
|
|
@@ -82,27 +95,28 @@ class Dialogue:
|
|
82
95
|
|
83
96
|
return pseudobulk
|
84
97
|
|
85
|
-
def
|
86
|
-
|
98
|
+
def _pseudobulk_feature_space(
|
99
|
+
self,
|
100
|
+
adata: AnnData,
|
101
|
+
groupby: str,
|
102
|
+
) -> pd.DataFrame:
|
103
|
+
"""Return Cell-averaged components from a passed feature space.
|
87
104
|
|
88
105
|
TODO: consider merging with `get_pseudobulks`
|
89
106
|
TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
|
90
107
|
|
91
108
|
Args:
|
92
|
-
|
93
|
-
|
109
|
+
adata: Annotated data matrix.
|
110
|
+
groupby: The key to groupby for pseudobulks.
|
94
111
|
|
95
112
|
Returns:
|
96
|
-
A pseudobulk of
|
113
|
+
A pseudobulk DataFrame of the averaged components.
|
97
114
|
"""
|
98
115
|
aggr = {}
|
99
|
-
|
100
116
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
101
117
|
temp = adata.obs.loc[:, groupby] == category
|
102
|
-
aggr[category] = adata[temp].obsm[
|
103
|
-
|
118
|
+
aggr[category] = adata[temp].obsm[self.feature_space_key][:, : self.n_components].mean(axis=0)
|
104
119
|
aggr = pd.DataFrame(aggr)
|
105
|
-
|
106
120
|
return aggr
|
107
121
|
|
108
122
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
@@ -130,6 +144,7 @@ class Dialogue:
|
|
130
144
|
|
131
145
|
Args:
|
132
146
|
adata: The AnnData object to append mcp scores to.
|
147
|
+
ct_subs: cell type objects.
|
133
148
|
mcp_scores: The MCP scores dictionary.
|
134
149
|
celltype_key: Key of the cell type column in obs.
|
135
150
|
|
@@ -213,7 +228,7 @@ class Dialogue:
|
|
213
228
|
sample_obs: str,
|
214
229
|
return_all: bool = False,
|
215
230
|
):
|
216
|
-
"""Applies a mixed linear model using the specified formula (MCP scores used for the dependent var) and returns the coefficient and p-value
|
231
|
+
"""Applies a mixed linear model using the specified formula (MCP scores used for the dependent var) and returns the coefficient and p-value.
|
217
232
|
|
218
233
|
TODO: reduce runtime? Maybe we can use an approximation or something that isn't statsmodels.
|
219
234
|
|
@@ -332,7 +347,7 @@ class Dialogue:
|
|
332
347
|
|
333
348
|
Args:
|
334
349
|
mcp_name: The name of the MCP to model.
|
335
|
-
|
350
|
+
scores_df: The MCP scores for a cell type. Number of MCPs x number of features.
|
336
351
|
ct_data: The AnnData object containing the metadata and labels in obs.
|
337
352
|
tme: Transcript mean expression in `x`.
|
338
353
|
sig: DataFrame containing a series of up and downregulated MCPs.
|
@@ -418,11 +433,10 @@ class Dialogue:
|
|
418
433
|
# Finally get corr coeff
|
419
434
|
return np.dot(A_mA, B_mB.T) / np.sqrt(np.dot(ssA[:, None], ssB[None]))
|
420
435
|
|
436
|
+
# TODO: needs check for correctness and variable renaming
|
437
|
+
# TODO: Confirm that this doesn't return duplicate gene names.
|
421
438
|
def _get_top_elements(self, m: pd.DataFrame, max_length: int, min_threshold: float):
|
422
|
-
"""
|
423
|
-
|
424
|
-
TODO: needs check for correctness and variable renaming
|
425
|
-
TODO: Confirm that this doesn't return duplicate gene names
|
439
|
+
"""Get top elements.
|
426
440
|
|
427
441
|
Args:
|
428
442
|
m: Any DataFrame of Gene name as index with variable columns.
|
@@ -457,12 +471,11 @@ class Dialogue:
|
|
457
471
|
# TODO this whole function should be standalone
|
458
472
|
# It will contain the calculation of up/down + calculation (new final mcp scores)
|
459
473
|
# Ensure that it'll still fit/work with the hierarchical multilevel_modeling
|
460
|
-
|
461
474
|
"""Determine the up and down genes per MCP."""
|
462
475
|
# TODO: something is slightly slow here
|
463
476
|
cca_sig_results: dict[Any, dict[str, Any]] = {}
|
464
477
|
new_mcp_scores: dict[Any, list[Any]] = {}
|
465
|
-
for ct in ct_subs
|
478
|
+
for ct in ct_subs:
|
466
479
|
ct_adata = ct_subs[ct]
|
467
480
|
conf_m = ct_adata.obs[n_counts_key].values
|
468
481
|
|
@@ -483,9 +496,7 @@ class Dialogue:
|
|
483
496
|
from scipy.stats import spearmanr
|
484
497
|
|
485
498
|
def _pcor_mat(v1, v2, v3, method="spearman"):
|
486
|
-
"""
|
487
|
-
MAJOR TODO: I've only used normal correlation instead of partial correlation as we wait on the implementation
|
488
|
-
"""
|
499
|
+
"""MAJOR TODO: I've only used normal correlation instead of partial correlation as we wait on the implementation."""
|
489
500
|
correlations = [] # R
|
490
501
|
pvals = [] # P
|
491
502
|
for x2 in v2:
|
@@ -506,7 +517,7 @@ class Dialogue:
|
|
506
517
|
return np.array(correlations), np.array(pvals) # pvals_adjusted
|
507
518
|
|
508
519
|
C1, P1 = _pcor_mat(ct_adata[:, top_cor_genes_flattened].X.toarray().T, mcp_scores[ct].T, conf_m)
|
509
|
-
C1[
|
520
|
+
C1[(0.05 / ct_adata.shape[1]) < P1] = 0 # why?
|
510
521
|
|
511
522
|
cca_sig_unformatted = self._get_top_elements( # 3 up, 3 dn, for each mcp
|
512
523
|
pd.DataFrame(C1.T, index=top_cor_genes_flattened), max_length=max_genes, min_threshold=0.05
|
@@ -514,7 +525,7 @@ class Dialogue:
|
|
514
525
|
|
515
526
|
# TODO: probably format the up and down within get_top_elements
|
516
527
|
cca_sig: dict[str, Any] = defaultdict(dict)
|
517
|
-
for i in range(
|
528
|
+
for i in range(int(len(cca_sig_unformatted) / 2)):
|
518
529
|
cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
|
519
530
|
cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
|
520
531
|
|
@@ -523,7 +534,7 @@ class Dialogue:
|
|
523
534
|
|
524
535
|
# This is basically DIALOGUE 3 now
|
525
536
|
pre_r_scores = {
|
526
|
-
ct: ct_subs[ct].obsm[
|
537
|
+
ct: ct_subs[ct].obsm[self.feature_space_key][:, : self.n_components] @ ws_dict[ct]
|
527
538
|
for i, ct in enumerate(ct_subs.keys())
|
528
539
|
# TODO This is a recalculation and not a new calculation
|
529
540
|
}
|
@@ -558,7 +569,7 @@ class Dialogue:
|
|
558
569
|
self,
|
559
570
|
adata: AnnData,
|
560
571
|
ct_order: list[str],
|
561
|
-
|
572
|
+
agg_feature: bool = True,
|
562
573
|
normalize: bool = True,
|
563
574
|
) -> tuple[list, dict]:
|
564
575
|
"""Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
|
@@ -568,14 +579,14 @@ class Dialogue:
|
|
568
579
|
Args:
|
569
580
|
adata: AnnData object generate celltype objects for
|
570
581
|
ct_order: The order of cell types
|
571
|
-
|
582
|
+
agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
|
572
583
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
573
584
|
|
574
585
|
Returns:
|
575
586
|
A celltype_label:array dictionary.
|
576
587
|
"""
|
577
588
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
578
|
-
fn = self.
|
589
|
+
fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
|
579
590
|
ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
|
580
591
|
|
581
592
|
# TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
|
@@ -591,9 +602,9 @@ class Dialogue:
|
|
591
602
|
def calculate_multifactor_PMD(
|
592
603
|
self,
|
593
604
|
adata: AnnData,
|
594
|
-
penalties: list[int] = None,
|
595
|
-
ct_order: list[str] = None,
|
596
|
-
|
605
|
+
penalties: list[int] | None = None,
|
606
|
+
ct_order: list[str] | None = None,
|
607
|
+
agg_feature: bool = True,
|
597
608
|
solver: Literal["lp", "bs"] = "bs",
|
598
609
|
normalize: bool = True,
|
599
610
|
) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
|
@@ -603,10 +614,9 @@ class Dialogue:
|
|
603
614
|
|
604
615
|
Args:
|
605
616
|
adata: AnnData object to calculate PMD for.
|
606
|
-
sample_id: Key to use for pseudobulk determination.
|
607
617
|
penalties: PMD penalties.
|
608
618
|
ct_order: The order of cell types.
|
609
|
-
|
619
|
+
agg_feature: Whether to calculate cell-averaged principal components.
|
610
620
|
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
|
611
621
|
For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
|
612
622
|
normalize: Whether to mimic DIALOGUE as close as possible
|
@@ -631,7 +641,7 @@ class Dialogue:
|
|
631
641
|
else:
|
632
642
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
633
643
|
|
634
|
-
mcca_in, ct_subs = self._load(adata, ct_order=cell_types,
|
644
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
|
635
645
|
|
636
646
|
n_samples = mcca_in[0].shape[1]
|
637
647
|
if penalties is None:
|
@@ -644,8 +654,6 @@ class Dialogue:
|
|
644
654
|
raise ValueError("Please ensure that every cell type is represented in every sample.") from e
|
645
655
|
else:
|
646
656
|
raise
|
647
|
-
else:
|
648
|
-
penalties = penalties
|
649
657
|
|
650
658
|
if solver == "bs":
|
651
659
|
ws, _ = multicca_pmd(mcca_in, penalties, K=self.n_mcps, standardize=True, niter=100, mimic_R=normalize)
|
@@ -656,8 +664,8 @@ class Dialogue:
|
|
656
664
|
ws_dict = {ct: ws[i] for i, ct in enumerate(ct_order)}
|
657
665
|
|
658
666
|
pre_r_scores = {
|
659
|
-
ct: ct_subs[ct].obsm[
|
660
|
-
for i, ct in enumerate(cell_types)
|
667
|
+
ct: ct_subs[ct].obsm[self.feature_space_key][:, : self.n_components] @ ws[i]
|
668
|
+
for i, ct in enumerate(cell_types)
|
661
669
|
}
|
662
670
|
|
663
671
|
# TODO: output format needs some cleanup, even though each MCP score is matched to one cell, it's not at all
|
@@ -681,17 +689,17 @@ class Dialogue:
|
|
681
689
|
ws_dict: dict,
|
682
690
|
confounder: str | None,
|
683
691
|
formula: str = None,
|
684
|
-
):
|
692
|
+
) -> pd.DataFrame:
|
685
693
|
"""Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
|
686
694
|
|
687
695
|
Args:
|
688
696
|
ct_subs: The DIALOGUE cell type objects.
|
689
697
|
mcp_scores: The determined MCP scores from the PMD step.
|
698
|
+
ws_dict: WS dictionary.
|
690
699
|
confounder: Any modeling confounders.
|
691
700
|
formula: The hierarchical modeling formula. Defaults to y ~ x + n_counts.
|
692
701
|
|
693
702
|
Returns:
|
694
|
-
A Pandas DataFrame containing:
|
695
703
|
- for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
|
696
704
|
- merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
|
697
705
|
|
@@ -875,15 +883,15 @@ class Dialogue:
|
|
875
883
|
if len(conditions_compare) != 2:
|
876
884
|
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
|
877
885
|
|
878
|
-
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(
|
879
|
-
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(
|
880
|
-
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(
|
886
|
+
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
887
|
+
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
888
|
+
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
881
889
|
|
882
890
|
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
|
883
891
|
for celltype in adata.obs[celltype_label].unique():
|
884
892
|
df = adata.obs[adata.obs[celltype_label] == celltype]
|
885
893
|
|
886
|
-
for mcpnum in ["mcp_" + str(n) for n in range(
|
894
|
+
for mcpnum in ["mcp_" + str(n) for n in range(n_mcps)]:
|
887
895
|
mns = df.groupby(sample_label)[mcpnum].mean()
|
888
896
|
mns = pd.concat([mns, response], axis=1)
|
889
897
|
res = stats.ttest_ind(
|
@@ -893,7 +901,7 @@ class Dialogue:
|
|
893
901
|
pvals.loc[celltype, mcpnum] = res[1]
|
894
902
|
tstats.loc[celltype, mcpnum] = res[0]
|
895
903
|
|
896
|
-
for mcpnum in ["mcp_" + str(n) for n in range(
|
904
|
+
for mcpnum in ["mcp_" + str(n) for n in range(n_mcps)]:
|
897
905
|
pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
|
898
906
|
|
899
907
|
return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
|
@@ -956,7 +964,7 @@ class Dialogue:
|
|
956
964
|
|
957
965
|
genes_dict_up = {} # type: ignore
|
958
966
|
genes_dict_down = {} # type: ignore
|
959
|
-
for celltype2 in mcp_dict
|
967
|
+
for celltype2 in mcp_dict:
|
960
968
|
for gene in mcp_dict[celltype2][MCP + ".up"]:
|
961
969
|
if gene in genes_dict_up:
|
962
970
|
genes_dict_up[gene] += 1
|
@@ -1008,7 +1016,7 @@ class Dialogue:
|
|
1008
1016
|
>>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
|
1009
1017
|
"""
|
1010
1018
|
genes = {}
|
1011
|
-
for ct in ct_subs
|
1019
|
+
for ct in ct_subs:
|
1012
1020
|
mini = ct_subs[ct]
|
1013
1021
|
mini.obs["extrema"] = pd.qcut(
|
1014
1022
|
mini.obs[mcp],
|
@@ -1056,13 +1064,13 @@ class Dialogue:
|
|
1056
1064
|
for mcp in mcps:
|
1057
1065
|
rank_dfs[mcp] = {}
|
1058
1066
|
ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
|
1059
|
-
for celltype in ct_ranked
|
1067
|
+
for celltype in ct_ranked:
|
1060
1068
|
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
|
1061
1069
|
|
1062
1070
|
return rank_dfs
|
1063
1071
|
|
1064
1072
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1065
|
-
def plot_split_violins(
|
1073
|
+
def plot_split_violins( # pragma: no cover # noqa: D417
|
1066
1074
|
self,
|
1067
1075
|
adata: AnnData,
|
1068
1076
|
split_key: str,
|
@@ -1070,7 +1078,6 @@ class Dialogue:
|
|
1070
1078
|
*,
|
1071
1079
|
split_which: tuple[str, str] = None,
|
1072
1080
|
mcp: str = "mcp_0",
|
1073
|
-
show: bool = True,
|
1074
1081
|
return_fig: bool = False,
|
1075
1082
|
) -> Figure | None:
|
1076
1083
|
"""Plots split violin plots for a given MCP and split variable.
|
@@ -1110,14 +1117,13 @@ class Dialogue:
|
|
1110
1117
|
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1111
1118
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1112
1119
|
|
1113
|
-
if show:
|
1114
|
-
plt.show()
|
1115
1120
|
if return_fig:
|
1116
1121
|
return plt.gcf()
|
1122
|
+
plt.show()
|
1117
1123
|
return None
|
1118
1124
|
|
1119
1125
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1120
|
-
def plot_pairplot(
|
1126
|
+
def plot_pairplot( # pragma: no cover # noqa: D417
|
1121
1127
|
self,
|
1122
1128
|
adata: AnnData,
|
1123
1129
|
celltype_key: str,
|
@@ -1125,7 +1131,6 @@ class Dialogue:
|
|
1125
1131
|
sample_id: str,
|
1126
1132
|
*,
|
1127
1133
|
mcp: str = "mcp_0",
|
1128
|
-
show: bool = True,
|
1129
1134
|
return_fig: bool = False,
|
1130
1135
|
) -> Figure | None:
|
1131
1136
|
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
@@ -1167,8 +1172,7 @@ class Dialogue:
|
|
1167
1172
|
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1168
1173
|
sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1169
1174
|
|
1170
|
-
if show:
|
1171
|
-
plt.show()
|
1172
1175
|
if return_fig:
|
1173
1176
|
return plt.gcf()
|
1177
|
+
plt.show()
|
1174
1178
|
return None
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import contextlib
|
1
2
|
import math
|
2
3
|
from abc import ABC, abstractmethod
|
3
4
|
from collections.abc import Iterable, Mapping, Sequence
|
@@ -23,8 +24,7 @@ from pertpy.tools._differential_gene_expression._checks import check_is_numeric_
|
|
23
24
|
|
24
25
|
class MethodBase(ABC):
|
25
26
|
def __init__(self, adata, *, mask=None, layer=None, **kwargs):
|
26
|
-
"""
|
27
|
-
Initialize the method.
|
27
|
+
"""Initialize the method.
|
28
28
|
|
29
29
|
Args:
|
30
30
|
adata: AnnData object, usually pseudobulked.
|
@@ -62,8 +62,7 @@ class MethodBase(ABC):
|
|
62
62
|
fit_kwargs=MappingProxyType({}),
|
63
63
|
test_kwargs=MappingProxyType({}),
|
64
64
|
):
|
65
|
-
"""
|
66
|
-
Compare between groups in a specified column.
|
65
|
+
"""Compare between groups in a specified column.
|
67
66
|
|
68
67
|
Args:
|
69
68
|
adata: AnnData object.
|
@@ -100,7 +99,7 @@ class MethodBase(ABC):
|
|
100
99
|
...
|
101
100
|
|
102
101
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
103
|
-
def plot_volcano(
|
102
|
+
def plot_volcano( # pragma: no cover # noqa: D417
|
104
103
|
self,
|
105
104
|
data: pd.DataFrame | ad.AnnData,
|
106
105
|
*,
|
@@ -125,7 +124,6 @@ class MethodBase(ABC):
|
|
125
124
|
shape_order: list[str] | None = None,
|
126
125
|
x_label: str | None = None,
|
127
126
|
y_label: str | None = None,
|
128
|
-
show: bool = True,
|
129
127
|
return_fig: bool = False,
|
130
128
|
**kwargs: int,
|
131
129
|
) -> Figure | None:
|
@@ -189,8 +187,7 @@ class MethodBase(ABC):
|
|
189
187
|
colors = ["gray", "#D62728", "#1F77B4"]
|
190
188
|
|
191
189
|
def _pval_reciprocal(lfc: float) -> float:
|
192
|
-
"""
|
193
|
-
Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
|
190
|
+
"""Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
|
194
191
|
|
195
192
|
Used for plotting the S-curve
|
196
193
|
"""
|
@@ -198,7 +195,7 @@ class MethodBase(ABC):
|
|
198
195
|
|
199
196
|
def _map_shape(symbol: str) -> str:
|
200
197
|
if shape_dict is not None:
|
201
|
-
for k in shape_dict
|
198
|
+
for k in shape_dict:
|
202
199
|
if shape_dict[k] is not None and symbol in shape_dict[k]:
|
203
200
|
return k
|
204
201
|
return "other"
|
@@ -212,8 +209,7 @@ class MethodBase(ABC):
|
|
212
209
|
pval_thresh: float = None,
|
213
210
|
s_curve: bool = False,
|
214
211
|
) -> str:
|
215
|
-
"""
|
216
|
-
Map genes to categorize based on log2fc and pvalue.
|
212
|
+
"""Map genes to categorize based on log2fc and pvalue.
|
217
213
|
|
218
214
|
These categories are used for coloring the dots.
|
219
215
|
Used when no color_dict is passed, sets up/down/nonsignificant.
|
@@ -230,14 +226,13 @@ class MethodBase(ABC):
|
|
230
226
|
return "Down"
|
231
227
|
else:
|
232
228
|
return "not DE"
|
229
|
+
# Standard condition for Up or Down categorization
|
230
|
+
elif log2fc > log2fc_thresh and nlog10 > pval_thresh:
|
231
|
+
return "Up"
|
232
|
+
elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
|
233
|
+
return "Down"
|
233
234
|
else:
|
234
|
-
|
235
|
-
if log2fc > log2fc_thresh and nlog10 > pval_thresh:
|
236
|
-
return "Up"
|
237
|
-
elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
|
238
|
-
return "Down"
|
239
|
-
else:
|
240
|
-
return "not DE"
|
235
|
+
return "not DE"
|
241
236
|
|
242
237
|
def _map_genes_categories_highlight(
|
243
238
|
row: pd.Series,
|
@@ -248,8 +243,7 @@ class MethodBase(ABC):
|
|
248
243
|
s_curve: bool = False,
|
249
244
|
symbol_col: str = None,
|
250
245
|
) -> str:
|
251
|
-
"""
|
252
|
-
Map genes to categorize based on log2fc and pvalue.
|
246
|
+
"""Map genes to categorize based on log2fc and pvalue.
|
253
247
|
|
254
248
|
These categories are used for coloring the dots.
|
255
249
|
Used when color_dict is passed, sets DE / not DE for background and user supplied highlight genes.
|
@@ -259,7 +253,7 @@ class MethodBase(ABC):
|
|
259
253
|
symbol = row[symbol_col]
|
260
254
|
|
261
255
|
if color_dict is not None:
|
262
|
-
for k in color_dict
|
256
|
+
for k in color_dict:
|
263
257
|
if symbol in color_dict[k]:
|
264
258
|
return k
|
265
259
|
|
@@ -484,14 +478,13 @@ class MethodBase(ABC):
|
|
484
478
|
|
485
479
|
plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
|
486
480
|
|
487
|
-
if show:
|
488
|
-
plt.show()
|
489
481
|
if return_fig:
|
490
482
|
return plt.gcf()
|
483
|
+
plt.show()
|
491
484
|
return None
|
492
485
|
|
493
486
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
494
|
-
def plot_paired(
|
487
|
+
def plot_paired( # pragma: no cover # noqa: D417
|
495
488
|
self,
|
496
489
|
adata: ad.AnnData,
|
497
490
|
results_df: pd.DataFrame,
|
@@ -511,7 +504,6 @@ class MethodBase(ABC):
|
|
511
504
|
pvalue_template=lambda x: f"p={x:.2e}",
|
512
505
|
boxplot_properties=None,
|
513
506
|
palette=None,
|
514
|
-
show: bool = True,
|
515
507
|
return_fig: bool = False,
|
516
508
|
) -> Figure | None:
|
517
509
|
"""Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
|
@@ -584,14 +576,9 @@ class MethodBase(ABC):
|
|
584
576
|
adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
|
585
577
|
)
|
586
578
|
|
587
|
-
if layer is not None
|
588
|
-
|
589
|
-
else:
|
590
|
-
X = adata.X
|
591
|
-
try:
|
579
|
+
X = adata.layers[layer] if layer is not None else adata.X
|
580
|
+
with contextlib.suppress(AttributeError):
|
592
581
|
X = X.toarray()
|
593
|
-
except AttributeError:
|
594
|
-
pass
|
595
582
|
|
596
583
|
groupby_cols = [pairedby, groupby]
|
597
584
|
df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))
|
@@ -679,14 +666,13 @@ class MethodBase(ABC):
|
|
679
666
|
)
|
680
667
|
|
681
668
|
plt.tight_layout()
|
682
|
-
if show:
|
683
|
-
plt.show()
|
684
669
|
if return_fig:
|
685
670
|
return plt.gcf()
|
671
|
+
plt.show()
|
686
672
|
return None
|
687
673
|
|
688
674
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
689
|
-
def plot_fold_change(
|
675
|
+
def plot_fold_change( # pragma: no cover # noqa: D417
|
690
676
|
self,
|
691
677
|
results_df: pd.DataFrame,
|
692
678
|
*,
|
@@ -696,7 +682,6 @@ class MethodBase(ABC):
|
|
696
682
|
symbol_col: str = "variable",
|
697
683
|
y_label: str = "Log2 fold change",
|
698
684
|
figsize: tuple[int, int] = (10, 5),
|
699
|
-
show: bool = True,
|
700
685
|
return_fig: bool = False,
|
701
686
|
**barplot_kwargs,
|
702
687
|
) -> Figure | None:
|
@@ -762,14 +747,13 @@ class MethodBase(ABC):
|
|
762
747
|
plt.xlabel("")
|
763
748
|
plt.ylabel(y_label)
|
764
749
|
|
765
|
-
if show:
|
766
|
-
plt.show()
|
767
750
|
if return_fig:
|
768
751
|
return plt.gcf()
|
752
|
+
plt.show()
|
769
753
|
return None
|
770
754
|
|
771
755
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
772
|
-
def plot_multicomparison_fc(
|
756
|
+
def plot_multicomparison_fc( # pragma: no cover # noqa: D417
|
773
757
|
self,
|
774
758
|
results_df: pd.DataFrame,
|
775
759
|
*,
|
@@ -782,7 +766,6 @@ class MethodBase(ABC):
|
|
782
766
|
figsize: tuple[int, int] = (10, 2),
|
783
767
|
x_label: str = "Contrast",
|
784
768
|
y_label: str = "Gene",
|
785
|
-
show: bool = True,
|
786
769
|
return_fig: bool = False,
|
787
770
|
**heatmap_kwargs,
|
788
771
|
) -> Figure | None:
|
@@ -880,10 +863,9 @@ class MethodBase(ABC):
|
|
880
863
|
plt.xlabel(x_label)
|
881
864
|
plt.ylabel(y_label)
|
882
865
|
|
883
|
-
if show:
|
884
|
-
plt.show()
|
885
866
|
if return_fig:
|
886
867
|
return plt.gcf()
|
868
|
+
plt.show()
|
887
869
|
return None
|
888
870
|
|
889
871
|
|
@@ -1021,7 +1003,7 @@ class LinearModelBase(MethodBase):
|
|
1021
1003
|
)
|
1022
1004
|
return self.formulaic_contrasts.cond(**kwargs)
|
1023
1005
|
|
1024
|
-
def contrast(self, *args, **kwargs):
|
1006
|
+
def contrast(self, *args, **kwargs): # noqa: D417
|
1025
1007
|
"""Build a simple contrast for pairwise comparisons.
|
1026
1008
|
|
1027
1009
|
Args:
|
@@ -16,9 +16,8 @@ def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
|
|
16
16
|
if issparse(array):
|
17
17
|
if np.any(~np.isfinite(array.data)):
|
18
18
|
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
19
|
-
|
20
|
-
|
21
|
-
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
19
|
+
elif np.any(~np.isfinite(array)):
|
20
|
+
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
22
21
|
|
23
22
|
|
24
23
|
def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
|
@@ -34,8 +33,7 @@ def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-
|
|
34
33
|
if issparse(array):
|
35
34
|
if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
|
36
35
|
raise ValueError("Non-zero elements of the matrix must be close to integer values.")
|
37
|
-
|
38
|
-
|
39
|
-
raise ValueError("Matrix must be a count matrix.")
|
36
|
+
elif array.dtype.kind != "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
|
37
|
+
raise ValueError("Matrix must be a count matrix.")
|
40
38
|
if (array < 0).sum() > 0:
|
41
39
|
raise ValueError("Non-zero elements of the matrix must be positive.")
|
@@ -36,16 +36,15 @@ class DGEEVAL:
|
|
36
36
|
if not de_key1 or not de_key2:
|
37
37
|
raise ValueError("Both `de_key1` and `de_key2` must be provided together if using `adata`.")
|
38
38
|
|
39
|
-
|
40
|
-
|
41
|
-
raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
|
39
|
+
elif de_df1 is None or de_df2 is None:
|
40
|
+
raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
|
42
41
|
|
43
42
|
if de_key1:
|
44
43
|
if not adata:
|
45
44
|
raise ValueError("`adata` should be provided with `de_key1` and `de_key2`. ")
|
46
|
-
assert all(
|
47
|
-
|
48
|
-
)
|
45
|
+
assert all(k in adata.uns for k in [de_key1, de_key2]), (
|
46
|
+
"Provided `de_key1` and `de_key2` must exist in `adata.uns`."
|
47
|
+
)
|
49
48
|
vars = adata.var_names
|
50
49
|
|
51
50
|
if de_df1 is not None:
|