pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -25,6 +25,8 @@ from sklearn.linear_model import LinearRegression
|
|
25
25
|
from sparsecca import lp_pmd, multicca_permute, multicca_pmd
|
26
26
|
from statsmodels.sandbox.stats.multicomp import multipletests
|
27
27
|
|
28
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
29
|
+
|
28
30
|
if TYPE_CHECKING:
|
29
31
|
from matplotlib.axes import Axes
|
30
32
|
from matplotlib.figure import Figure
|
@@ -80,27 +82,27 @@ class Dialogue:
|
|
80
82
|
|
81
83
|
return pseudobulk
|
82
84
|
|
83
|
-
def
|
84
|
-
|
85
|
+
def _pseudobulk_feature_space(
|
86
|
+
self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
|
87
|
+
) -> pd.DataFrame:
|
88
|
+
"""Return Cell-averaged components from a passed feature space.
|
85
89
|
|
86
90
|
TODO: consider merging with `get_pseudobulks`
|
87
91
|
TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
|
88
92
|
|
89
93
|
Args:
|
90
|
-
groupby: The key to groupby for pseudobulks
|
91
|
-
n_components: The number of
|
94
|
+
groupby: The key to groupby for pseudobulks.
|
95
|
+
n_components: The number of components to use.
|
96
|
+
feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
|
92
97
|
|
93
98
|
Returns:
|
94
|
-
A pseudobulk of
|
99
|
+
A pseudobulk DataFrame of the averaged components.
|
95
100
|
"""
|
96
101
|
aggr = {}
|
97
|
-
|
98
102
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
99
103
|
temp = adata.obs.loc[:, groupby] == category
|
100
|
-
aggr[category] = adata[temp].obsm[
|
101
|
-
|
104
|
+
aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
|
102
105
|
aggr = pd.DataFrame(aggr)
|
103
|
-
|
104
106
|
return aggr
|
105
107
|
|
106
108
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
@@ -556,7 +558,7 @@ class Dialogue:
|
|
556
558
|
self,
|
557
559
|
adata: AnnData,
|
558
560
|
ct_order: list[str],
|
559
|
-
|
561
|
+
agg_feature: bool = True,
|
560
562
|
normalize: bool = True,
|
561
563
|
) -> tuple[list, dict]:
|
562
564
|
"""Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
|
@@ -566,14 +568,14 @@ class Dialogue:
|
|
566
568
|
Args:
|
567
569
|
adata: AnnData object generate celltype objects for
|
568
570
|
ct_order: The order of cell types
|
569
|
-
|
571
|
+
agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
|
570
572
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
571
573
|
|
572
574
|
Returns:
|
573
575
|
A celltype_label:array dictionary.
|
574
576
|
"""
|
575
577
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
576
|
-
fn = self.
|
578
|
+
fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
|
577
579
|
ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
|
578
580
|
|
579
581
|
# TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
|
@@ -591,7 +593,7 @@ class Dialogue:
|
|
591
593
|
adata: AnnData,
|
592
594
|
penalties: list[int] = None,
|
593
595
|
ct_order: list[str] = None,
|
594
|
-
|
596
|
+
agg_feature: bool = True,
|
595
597
|
solver: Literal["lp", "bs"] = "bs",
|
596
598
|
normalize: bool = True,
|
597
599
|
) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
|
@@ -604,7 +606,7 @@ class Dialogue:
|
|
604
606
|
sample_id: Key to use for pseudobulk determination.
|
605
607
|
penalties: PMD penalties.
|
606
608
|
ct_order: The order of cell types.
|
607
|
-
|
609
|
+
agg_features: Whether to calculate cell-averaged principal components.
|
608
610
|
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
|
609
611
|
For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
|
610
612
|
normalize: Whether to mimic DIALOGUE as close as possible
|
@@ -629,7 +631,7 @@ class Dialogue:
|
|
629
631
|
else:
|
630
632
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
631
633
|
|
632
|
-
mcca_in, ct_subs = self._load(adata, ct_order=cell_types,
|
634
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
|
633
635
|
|
634
636
|
n_samples = mcca_in[0].shape[1]
|
635
637
|
if penalties is None:
|
@@ -1059,18 +1061,17 @@ class Dialogue:
|
|
1059
1061
|
|
1060
1062
|
return rank_dfs
|
1061
1063
|
|
1064
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1062
1065
|
def plot_split_violins(
|
1063
1066
|
self,
|
1064
1067
|
adata: AnnData,
|
1065
1068
|
split_key: str,
|
1066
1069
|
celltype_key: str,
|
1070
|
+
*,
|
1067
1071
|
split_which: tuple[str, str] = None,
|
1068
1072
|
mcp: str = "mcp_0",
|
1069
|
-
return_fig: bool
|
1070
|
-
|
1071
|
-
save: bool | str | None = None,
|
1072
|
-
show: bool | None = None,
|
1073
|
-
) -> Axes | Figure | None:
|
1073
|
+
return_fig: bool = False,
|
1074
|
+
) -> Figure | None:
|
1074
1075
|
"""Plots split violin plots for a given MCP and split variable.
|
1075
1076
|
|
1076
1077
|
Any cells with a value for split_key not in split_which are removed from the plot.
|
@@ -1081,9 +1082,10 @@ class Dialogue:
|
|
1081
1082
|
celltype_key: Key for cell type annotations.
|
1082
1083
|
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
1083
1084
|
mcp: Key for MCP data.
|
1085
|
+
{common_plot_args}
|
1084
1086
|
|
1085
1087
|
Returns:
|
1086
|
-
|
1088
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1087
1089
|
|
1088
1090
|
Examples:
|
1089
1091
|
>>> import pertpy as pt
|
@@ -1105,30 +1107,24 @@ class Dialogue:
|
|
1105
1107
|
df[split_key] = df[split_key].cat.remove_unused_categories()
|
1106
1108
|
|
1107
1109
|
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1108
|
-
|
1109
1110
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1110
1111
|
|
1111
|
-
if save:
|
1112
|
-
plt.savefig(save, bbox_inches="tight")
|
1113
|
-
if show:
|
1114
|
-
plt.show()
|
1115
1112
|
if return_fig:
|
1116
1113
|
return plt.gcf()
|
1117
|
-
|
1118
|
-
return ax
|
1114
|
+
plt.show()
|
1119
1115
|
return None
|
1120
1116
|
|
1117
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1121
1118
|
def plot_pairplot(
|
1122
1119
|
self,
|
1123
1120
|
adata: AnnData,
|
1124
1121
|
celltype_key: str,
|
1125
1122
|
color: str,
|
1126
1123
|
sample_id: str,
|
1124
|
+
*,
|
1127
1125
|
mcp: str = "mcp_0",
|
1128
|
-
return_fig: bool
|
1129
|
-
|
1130
|
-
save: bool | str | None = None,
|
1131
|
-
) -> PairGrid | Figure | None:
|
1126
|
+
return_fig: bool = False,
|
1127
|
+
) -> Figure | None:
|
1132
1128
|
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
1133
1129
|
|
1134
1130
|
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
@@ -1140,9 +1136,10 @@ class Dialogue:
|
|
1140
1136
|
color: Key in `adata.obs` for color annotations. This parameter is used as the hue
|
1141
1137
|
sample_id: Key in `adata.obs` for the sample annotations.
|
1142
1138
|
mcp: Key in `adata.obs` for MCP feature values.
|
1139
|
+
{common_plot_args}
|
1143
1140
|
|
1144
1141
|
Returns:
|
1145
|
-
|
1142
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1146
1143
|
|
1147
1144
|
Examples:
|
1148
1145
|
>>> import pertpy as pt
|
@@ -1165,14 +1162,9 @@ class Dialogue:
|
|
1165
1162
|
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
1166
1163
|
aggstats[color] = aggstats["top"]
|
1167
1164
|
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1168
|
-
|
1165
|
+
sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1169
1166
|
|
1170
|
-
if save:
|
1171
|
-
plt.savefig(save, bbox_inches="tight")
|
1172
|
-
if show:
|
1173
|
-
plt.show()
|
1174
1167
|
if return_fig:
|
1175
1168
|
return plt.gcf()
|
1176
|
-
|
1177
|
-
return ax
|
1169
|
+
plt.show()
|
1178
1170
|
return None
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from ._base import
|
1
|
+
from ._base import LinearModelBase, MethodBase
|
2
2
|
from ._dge_comparison import DGEEVAL
|
3
3
|
from ._edger import EdgeR
|
4
4
|
from ._pydeseq2 import PyDESeq2
|
@@ -14,7 +14,6 @@ __all__ = [
|
|
14
14
|
"SimpleComparisonBase",
|
15
15
|
"WilcoxonTest",
|
16
16
|
"TTest",
|
17
|
-
"ContrastType",
|
18
17
|
]
|
19
18
|
|
20
19
|
AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
|