pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py
CHANGED
@@ -25,6 +25,8 @@ from sklearn.linear_model import LinearRegression
|
|
25
25
|
from sparsecca import lp_pmd, multicca_permute, multicca_pmd
|
26
26
|
from statsmodels.sandbox.stats.multicomp import multipletests
|
27
27
|
|
28
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
29
|
+
|
28
30
|
if TYPE_CHECKING:
|
29
31
|
from matplotlib.axes import Axes
|
30
32
|
from matplotlib.figure import Figure
|
@@ -80,27 +82,27 @@ class Dialogue:
|
|
80
82
|
|
81
83
|
return pseudobulk
|
82
84
|
|
83
|
-
def
|
84
|
-
|
85
|
+
def _pseudobulk_feature_space(
|
86
|
+
self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
|
87
|
+
) -> pd.DataFrame:
|
88
|
+
"""Return Cell-averaged components from a passed feature space.
|
85
89
|
|
86
90
|
TODO: consider merging with `get_pseudobulks`
|
87
91
|
TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
|
88
92
|
|
89
93
|
Args:
|
90
|
-
groupby: The key to groupby for pseudobulks
|
91
|
-
n_components: The number of
|
94
|
+
groupby: The key to groupby for pseudobulks.
|
95
|
+
n_components: The number of components to use.
|
96
|
+
feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
|
92
97
|
|
93
98
|
Returns:
|
94
|
-
A pseudobulk of
|
99
|
+
A pseudobulk DataFrame of the averaged components.
|
95
100
|
"""
|
96
101
|
aggr = {}
|
97
|
-
|
98
102
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
99
103
|
temp = adata.obs.loc[:, groupby] == category
|
100
|
-
aggr[category] = adata[temp].obsm[
|
101
|
-
|
104
|
+
aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
|
102
105
|
aggr = pd.DataFrame(aggr)
|
103
|
-
|
104
106
|
return aggr
|
105
107
|
|
106
108
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
@@ -556,7 +558,7 @@ class Dialogue:
|
|
556
558
|
self,
|
557
559
|
adata: AnnData,
|
558
560
|
ct_order: list[str],
|
559
|
-
|
561
|
+
agg_feature: bool = True,
|
560
562
|
normalize: bool = True,
|
561
563
|
) -> tuple[list, dict]:
|
562
564
|
"""Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
|
@@ -566,14 +568,14 @@ class Dialogue:
|
|
566
568
|
Args:
|
567
569
|
adata: AnnData object generate celltype objects for
|
568
570
|
ct_order: The order of cell types
|
569
|
-
|
571
|
+
agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
|
570
572
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
571
573
|
|
572
574
|
Returns:
|
573
575
|
A celltype_label:array dictionary.
|
574
576
|
"""
|
575
577
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
576
|
-
fn = self.
|
578
|
+
fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
|
577
579
|
ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
|
578
580
|
|
579
581
|
# TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
|
@@ -591,7 +593,7 @@ class Dialogue:
|
|
591
593
|
adata: AnnData,
|
592
594
|
penalties: list[int] = None,
|
593
595
|
ct_order: list[str] = None,
|
594
|
-
|
596
|
+
agg_feature: bool = True,
|
595
597
|
solver: Literal["lp", "bs"] = "bs",
|
596
598
|
normalize: bool = True,
|
597
599
|
) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
|
@@ -604,7 +606,7 @@ class Dialogue:
|
|
604
606
|
sample_id: Key to use for pseudobulk determination.
|
605
607
|
penalties: PMD penalties.
|
606
608
|
ct_order: The order of cell types.
|
607
|
-
|
609
|
+
agg_features: Whether to calculate cell-averaged principal components.
|
608
610
|
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
|
609
611
|
For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
|
610
612
|
normalize: Whether to mimic DIALOGUE as close as possible
|
@@ -629,7 +631,7 @@ class Dialogue:
|
|
629
631
|
else:
|
630
632
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
631
633
|
|
632
|
-
mcca_in, ct_subs = self._load(adata, ct_order=cell_types,
|
634
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
|
633
635
|
|
634
636
|
n_samples = mcca_in[0].shape[1]
|
635
637
|
if penalties is None:
|
@@ -1059,18 +1061,17 @@ class Dialogue:
|
|
1059
1061
|
|
1060
1062
|
return rank_dfs
|
1061
1063
|
|
1064
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1062
1065
|
def plot_split_violins(
|
1063
1066
|
self,
|
1064
1067
|
adata: AnnData,
|
1065
1068
|
split_key: str,
|
1066
1069
|
celltype_key: str,
|
1070
|
+
*,
|
1067
1071
|
split_which: tuple[str, str] = None,
|
1068
1072
|
mcp: str = "mcp_0",
|
1069
|
-
return_fig: bool
|
1070
|
-
|
1071
|
-
save: bool | str | None = None,
|
1072
|
-
show: bool | None = None,
|
1073
|
-
) -> Axes | Figure | None:
|
1073
|
+
return_fig: bool = False,
|
1074
|
+
) -> Figure | None:
|
1074
1075
|
"""Plots split violin plots for a given MCP and split variable.
|
1075
1076
|
|
1076
1077
|
Any cells with a value for split_key not in split_which are removed from the plot.
|
@@ -1081,9 +1082,10 @@ class Dialogue:
|
|
1081
1082
|
celltype_key: Key for cell type annotations.
|
1082
1083
|
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
1083
1084
|
mcp: Key for MCP data.
|
1085
|
+
{common_plot_args}
|
1084
1086
|
|
1085
1087
|
Returns:
|
1086
|
-
|
1088
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1087
1089
|
|
1088
1090
|
Examples:
|
1089
1091
|
>>> import pertpy as pt
|
@@ -1105,30 +1107,24 @@ class Dialogue:
|
|
1105
1107
|
df[split_key] = df[split_key].cat.remove_unused_categories()
|
1106
1108
|
|
1107
1109
|
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1108
|
-
|
1109
1110
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1110
1111
|
|
1111
|
-
if save:
|
1112
|
-
plt.savefig(save, bbox_inches="tight")
|
1113
|
-
if show:
|
1114
|
-
plt.show()
|
1115
1112
|
if return_fig:
|
1116
1113
|
return plt.gcf()
|
1117
|
-
|
1118
|
-
return ax
|
1114
|
+
plt.show()
|
1119
1115
|
return None
|
1120
1116
|
|
1117
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1121
1118
|
def plot_pairplot(
|
1122
1119
|
self,
|
1123
1120
|
adata: AnnData,
|
1124
1121
|
celltype_key: str,
|
1125
1122
|
color: str,
|
1126
1123
|
sample_id: str,
|
1124
|
+
*,
|
1127
1125
|
mcp: str = "mcp_0",
|
1128
|
-
return_fig: bool
|
1129
|
-
|
1130
|
-
save: bool | str | None = None,
|
1131
|
-
) -> PairGrid | Figure | None:
|
1126
|
+
return_fig: bool = False,
|
1127
|
+
) -> Figure | None:
|
1132
1128
|
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
1133
1129
|
|
1134
1130
|
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
@@ -1140,9 +1136,10 @@ class Dialogue:
|
|
1140
1136
|
color: Key in `adata.obs` for color annotations. This parameter is used as the hue
|
1141
1137
|
sample_id: Key in `adata.obs` for the sample annotations.
|
1142
1138
|
mcp: Key in `adata.obs` for MCP feature values.
|
1139
|
+
{common_plot_args}
|
1143
1140
|
|
1144
1141
|
Returns:
|
1145
|
-
|
1142
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1146
1143
|
|
1147
1144
|
Examples:
|
1148
1145
|
>>> import pertpy as pt
|
@@ -1165,14 +1162,9 @@ class Dialogue:
|
|
1165
1162
|
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
1166
1163
|
aggstats[color] = aggstats["top"]
|
1167
1164
|
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1168
|
-
|
1165
|
+
sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1169
1166
|
|
1170
|
-
if save:
|
1171
|
-
plt.savefig(save, bbox_inches="tight")
|
1172
|
-
if show:
|
1173
|
-
plt.show()
|
1174
1167
|
if return_fig:
|
1175
1168
|
return plt.gcf()
|
1176
|
-
|
1177
|
-
return ax
|
1169
|
+
plt.show()
|
1178
1170
|
return None
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from ._base import
|
1
|
+
from ._base import LinearModelBase, MethodBase
|
2
2
|
from ._dge_comparison import DGEEVAL
|
3
3
|
from ._edger import EdgeR
|
4
4
|
from ._pydeseq2 import PyDESeq2
|
@@ -14,7 +14,6 @@ __all__ = [
|
|
14
14
|
"SimpleComparisonBase",
|
15
15
|
"WilcoxonTest",
|
16
16
|
"TTest",
|
17
|
-
"ContrastType",
|
18
17
|
]
|
19
18
|
|
20
19
|
AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
|