pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
5
|
-
from
|
5
|
+
from lamin_utils import logger
|
6
6
|
from scipy.sparse import issparse
|
7
7
|
|
8
8
|
from ._base import LinearModelBase
|
@@ -27,16 +27,11 @@ class EdgeR(LinearModelBase):
|
|
27
27
|
# pandas2ri.activate()
|
28
28
|
# rpy2.robjects.numpy2ri.activate()
|
29
29
|
try:
|
30
|
-
import rpy2.robjects.numpy2ri
|
31
|
-
import rpy2.robjects.pandas2ri
|
32
30
|
from rpy2 import robjects as ro
|
33
31
|
from rpy2.robjects import numpy2ri, pandas2ri
|
34
|
-
from rpy2.robjects.conversion import localconverter
|
32
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
35
33
|
from rpy2.robjects.packages import importr
|
36
34
|
|
37
|
-
pandas2ri.activate()
|
38
|
-
rpy2.robjects.numpy2ri.activate()
|
39
|
-
|
40
35
|
except ImportError:
|
41
36
|
raise ImportError("edger requires rpy2 to be installed.") from None
|
42
37
|
|
@@ -49,25 +44,30 @@ class EdgeR(LinearModelBase):
|
|
49
44
|
) from e
|
50
45
|
|
51
46
|
# Convert dataframe
|
52
|
-
with localconverter(
|
47
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
53
48
|
expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
|
54
49
|
if issparse(expr):
|
55
50
|
expr = expr.T.toarray()
|
56
51
|
else:
|
57
52
|
expr = expr.T
|
58
53
|
|
59
|
-
|
54
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
55
|
+
expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
|
56
|
+
samples_r = ro.conversion.py2rpy(self.adata.obs)
|
60
57
|
|
61
|
-
dge = edger.DGEList(counts=expr_r, samples=
|
58
|
+
dge = edger.DGEList(counts=expr_r, samples=samples_r)
|
62
59
|
|
63
|
-
|
60
|
+
logger.info("Calculating NormFactors")
|
64
61
|
dge = edger.calcNormFactors(dge)
|
65
62
|
|
66
|
-
|
67
|
-
|
63
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
64
|
+
design_r = ro.conversion.py2rpy(self.design.values)
|
65
|
+
|
66
|
+
logger.info("Estimating Dispersions")
|
67
|
+
dge = edger.estimateDisp(dge, design=design_r)
|
68
68
|
|
69
|
-
|
70
|
-
fit = edger.glmQLFit(dge, design=
|
69
|
+
logger.info("Fitting linear model")
|
70
|
+
fit = edger.glmQLFit(dge, design=design_r, **kwargs)
|
71
71
|
|
72
72
|
ro.globalenv["fit"] = fit
|
73
73
|
self.fit = fit
|
@@ -88,11 +88,9 @@ class EdgeR(LinearModelBase):
|
|
88
88
|
# Fix mask for .fit()
|
89
89
|
|
90
90
|
try:
|
91
|
-
import rpy2.robjects.numpy2ri
|
92
|
-
import rpy2.robjects.pandas2ri
|
93
91
|
from rpy2 import robjects as ro
|
94
92
|
from rpy2.robjects import numpy2ri, pandas2ri
|
95
|
-
from rpy2.robjects.conversion import localconverter
|
93
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
96
94
|
from rpy2.robjects.packages import importr
|
97
95
|
|
98
96
|
except ImportError:
|
@@ -106,7 +104,8 @@ class EdgeR(LinearModelBase):
|
|
106
104
|
) from None
|
107
105
|
|
108
106
|
# Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
|
109
|
-
|
107
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
108
|
+
contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
|
110
109
|
ro.globalenv["contrast_vec"] = contrast_vec_r
|
111
110
|
|
112
111
|
# Test contrast with R
|
@@ -117,8 +116,18 @@ class EdgeR(LinearModelBase):
|
|
117
116
|
"""
|
118
117
|
)
|
119
118
|
|
120
|
-
#
|
121
|
-
de_res = ro.
|
119
|
+
# Retrieve the `de_res` object
|
120
|
+
de_res = ro.globalenv["de_res"]
|
121
|
+
|
122
|
+
# If already a Pandas DataFrame, return it directly
|
123
|
+
if isinstance(de_res, pd.DataFrame):
|
124
|
+
de_res.index.name = "variable"
|
125
|
+
return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
|
126
|
+
|
127
|
+
# Convert to Pandas DataFrame if still an R object
|
128
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
129
|
+
de_res = ro.conversion.rpy2py(de_res)
|
130
|
+
|
122
131
|
de_res.index.name = "variable"
|
123
132
|
de_res = de_res.reset_index()
|
124
133
|
|
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
import re
|
3
3
|
import warnings
|
4
4
|
|
5
|
+
import numpy as np
|
5
6
|
import pandas as pd
|
6
7
|
from anndata import AnnData
|
7
8
|
from numpy import ndarray
|
@@ -40,33 +41,25 @@ class PyDESeq2(LinearModelBase):
|
|
40
41
|
Args:
|
41
42
|
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
|
42
43
|
"""
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
44
|
+
try:
|
45
|
+
usable_cpus = len(os.sched_getaffinity(0))
|
46
|
+
except AttributeError:
|
47
|
+
usable_cpus = os.cpu_count()
|
48
|
+
|
49
|
+
inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
|
50
|
+
|
50
51
|
dds = DeseqDataSet(
|
51
|
-
adata=self.adata,
|
52
|
+
adata=self.adata,
|
53
|
+
design=self.design, # initialize using design matrix, not formula
|
54
|
+
refit_cooks=True,
|
55
|
+
inference=inference,
|
56
|
+
**kwargs,
|
52
57
|
)
|
53
|
-
# workaround code to insert design array
|
54
|
-
des_mtx_cols = dds.obsm["design_matrix"].columns
|
55
|
-
dds.obsm["design_matrix"] = self.design
|
56
|
-
if dds.obsm["design_matrix"].shape[1] == len(des_mtx_cols):
|
57
|
-
dds.obsm["design_matrix"].columns = des_mtx_cols.copy()
|
58
58
|
|
59
59
|
dds.deseq2()
|
60
60
|
self.dds = dds
|
61
61
|
|
62
|
-
|
63
|
-
# see https://github.com/owkin/PyDESeq2/issues/213
|
64
|
-
|
65
|
-
# Therefore these functions are overridden in a way to make it work with PyDESeq2,
|
66
|
-
# ingoring the inconsistency of function signatures. Once arbitrary design
|
67
|
-
# matrices and contrasts are supported by PyDEseq2, we can fully support the
|
68
|
-
# Linear model interface.
|
69
|
-
def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame: # type: ignore
|
62
|
+
def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
|
70
63
|
"""Conduct a specific test and returns a Pandas DataFrame.
|
71
64
|
|
72
65
|
Args:
|
@@ -74,6 +67,7 @@ class PyDESeq2(LinearModelBase):
|
|
74
67
|
alpha: p value threshold used for controlling fdr with independent hypothesis weighting
|
75
68
|
**kwargs: extra arguments to pass to DeseqStats()
|
76
69
|
"""
|
70
|
+
contrast = np.array(contrast)
|
77
71
|
stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
|
78
72
|
# Calling `.summary()` is required to fill the `results_df` data frame
|
79
73
|
stat_res.summary()
|
@@ -85,11 +79,3 @@ class PyDESeq2(LinearModelBase):
|
|
85
79
|
res_df.index.name = "variable"
|
86
80
|
res_df = res_df.reset_index()
|
87
81
|
return res_df
|
88
|
-
|
89
|
-
def cond(self, **kwargs) -> ndarray:
|
90
|
-
raise NotImplementedError(
|
91
|
-
"PyDESeq2 currently doesn't support arbitrary contrasts, see https://github.com/owkin/PyDESeq2/issues/213"
|
92
|
-
)
|
93
|
-
|
94
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> tuple[str, str, str]: # type: ignore
|
95
|
-
return (column, group_to_compare, baseline)
|
@@ -59,14 +59,3 @@ class Statsmodels(LinearModelBase):
|
|
59
59
|
}
|
60
60
|
)
|
61
61
|
return pd.DataFrame(res).sort_values("p_value")
|
62
|
-
|
63
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndarray:
|
64
|
-
"""Build a simple contrast for pairwise comparisons.
|
65
|
-
|
66
|
-
This is equivalent to
|
67
|
-
|
68
|
-
```
|
69
|
-
model.cond(<column> = baseline) - model.cond(<column> = group_to_compare)
|
70
|
-
```
|
71
|
-
"""
|
72
|
-
return self.cond(**{column: baseline}) - self.cond(**{column: group_to_compare})
|
@@ -344,9 +344,9 @@ class Distance:
|
|
344
344
|
else:
|
345
345
|
embedding = adata.obsm[self.obsm_key].copy()
|
346
346
|
for index_x, group_x in enumerate(fct(groups)):
|
347
|
-
cells_x = embedding[grouping == group_x].copy()
|
347
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
348
348
|
for group_y in groups[index_x:]: # type: ignore
|
349
|
-
cells_y = embedding[grouping == group_y].copy()
|
349
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
350
350
|
if not bootstrap:
|
351
351
|
# By distance axiom, the distance between a group and itself is 0
|
352
352
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -478,9 +478,9 @@ class Distance:
|
|
478
478
|
else:
|
479
479
|
embedding = adata.obsm[self.obsm_key].copy()
|
480
480
|
for group_x in fct(groups):
|
481
|
-
cells_x = embedding[grouping == group_x].copy()
|
481
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
482
482
|
group_y = selected_group
|
483
|
-
cells_y = embedding[grouping == group_y].copy()
|
483
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
484
484
|
if not bootstrap:
|
485
485
|
# By distance axiom, the distance between a group and itself is 0
|
486
486
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -691,17 +691,18 @@ class MMD(AbstractDistance):
|
|
691
691
|
|
692
692
|
|
693
693
|
class WassersteinDistance(AbstractDistance):
|
694
|
-
"""Wasserstein distance metric (solved with entropy regularized Sinkhorn)."""
|
695
|
-
|
696
694
|
def __init__(self) -> None:
|
697
695
|
super().__init__()
|
698
696
|
self.accepts_precomputed = False
|
699
697
|
|
700
698
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
699
|
+
X = np.asarray(X, dtype=np.float64)
|
700
|
+
Y = np.asarray(Y, dtype=np.float64)
|
701
701
|
geom = PointCloud(X, Y)
|
702
702
|
return self.solve_ot_problem(geom, **kwargs)
|
703
703
|
|
704
704
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
705
|
+
P = np.asarray(P, dtype=np.float64)
|
705
706
|
geom = Geometry(cost_matrix=P[idx, :][:, ~idx])
|
706
707
|
return self.solve_ot_problem(geom, **kwargs)
|
707
708
|
|
@@ -709,7 +710,13 @@ class WassersteinDistance(AbstractDistance):
|
|
709
710
|
ot_prob = LinearProblem(geom)
|
710
711
|
solver = Sinkhorn()
|
711
712
|
ot = solver(ot_prob, **kwargs)
|
712
|
-
|
713
|
+
cost = float(ot.reg_ot_cost)
|
714
|
+
|
715
|
+
# Check for NaN or invalid cost
|
716
|
+
if not np.isfinite(cost):
|
717
|
+
return 1.0
|
718
|
+
else:
|
719
|
+
return cost
|
713
720
|
|
714
721
|
|
715
722
|
class EuclideanDistance(AbstractDistance):
|
@@ -981,7 +988,7 @@ class NBLL(AbstractDistance):
|
|
981
988
|
try:
|
982
989
|
nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
|
983
990
|
return _compute_nll(y, nb_params, epsilon)
|
984
|
-
except np.linalg.
|
991
|
+
except np.linalg.LinAlgError:
|
985
992
|
if x.mean() < 10 and y.mean() < 10:
|
986
993
|
return 0.0
|
987
994
|
else:
|
@@ -1110,67 +1117,75 @@ class MeanVarDistributionDistance(AbstractDistance):
|
|
1110
1117
|
super().__init__()
|
1111
1118
|
self.accepts_precomputed = False
|
1112
1119
|
|
1120
|
+
@staticmethod
|
1121
|
+
def _mean_var(x, log: bool = False):
|
1122
|
+
mean = np.mean(x, axis=0)
|
1123
|
+
var = np.var(x, axis=0)
|
1124
|
+
positive = mean > 0
|
1125
|
+
mean = mean[positive]
|
1126
|
+
var = var[positive]
|
1127
|
+
if log:
|
1128
|
+
mean = np.log(mean)
|
1129
|
+
var = np.log(var)
|
1130
|
+
return mean, var
|
1131
|
+
|
1132
|
+
@staticmethod
|
1133
|
+
def _prep_kde_data(x, y):
|
1134
|
+
return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
|
1135
|
+
|
1136
|
+
@staticmethod
|
1137
|
+
def _grid_points(d, n_points=100):
|
1138
|
+
# Make grid, add 1 bin on lower/upper end to get final n_points
|
1139
|
+
d_min = d.min()
|
1140
|
+
d_max = d.max()
|
1141
|
+
# Compute bin size
|
1142
|
+
d_bin = (d_max - d_min) / (n_points - 2)
|
1143
|
+
d_min = d_min - d_bin
|
1144
|
+
d_max = d_max + d_bin
|
1145
|
+
return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
|
1146
|
+
|
1147
|
+
@staticmethod
|
1148
|
+
def _kde_eval_both(x_kde, y_kde, grid):
|
1149
|
+
n_points = len(grid)
|
1150
|
+
chunk_size = 10000
|
1151
|
+
|
1152
|
+
result_x = np.zeros(n_points)
|
1153
|
+
result_y = np.zeros(n_points)
|
1154
|
+
|
1155
|
+
# Process same chunks for both KDEs
|
1156
|
+
for start in range(0, n_points, chunk_size):
|
1157
|
+
end = min(start + chunk_size, n_points)
|
1158
|
+
chunk = grid[start:end]
|
1159
|
+
result_x[start:end] = x_kde.score_samples(chunk)
|
1160
|
+
result_y[start:end] = y_kde.score_samples(chunk)
|
1161
|
+
|
1162
|
+
return result_x, result_y
|
1163
|
+
|
1113
1164
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
1114
1165
|
"""Difference of mean-var distributions in 2 matrices.
|
1115
|
-
|
1116
1166
|
Args:
|
1117
1167
|
X: Normalized and log transformed cells x genes count matrix.
|
1118
1168
|
Y: Normalized and log transformed cells x genes count matrix.
|
1119
1169
|
"""
|
1170
|
+
mean_x, var_x = self._mean_var(X, log=True)
|
1171
|
+
mean_y, var_y = self._mean_var(Y, log=True)
|
1120
1172
|
|
1121
|
-
|
1122
|
-
|
1123
|
-
var = np.var(x, axis=0)
|
1124
|
-
positive = mean > 0
|
1125
|
-
mean = mean[positive]
|
1126
|
-
var = var[positive]
|
1127
|
-
if log:
|
1128
|
-
mean = np.log(mean)
|
1129
|
-
var = np.log(var)
|
1130
|
-
return mean, var
|
1131
|
-
|
1132
|
-
def _prep_kde_data(x, y):
|
1133
|
-
return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
|
1134
|
-
|
1135
|
-
def _grid_points(d, n_points=100):
|
1136
|
-
# Make grid, add 1 bin on lower/upper end to get final n_points
|
1137
|
-
d_min = d.min()
|
1138
|
-
d_max = d.max()
|
1139
|
-
# Compute bin size
|
1140
|
-
d_bin = (d_max - d_min) / (n_points - 2)
|
1141
|
-
d_min = d_min - d_bin
|
1142
|
-
d_max = d_max + d_bin
|
1143
|
-
return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
|
1144
|
-
|
1145
|
-
def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
|
1146
|
-
# the thread_count is determined using the factor 0.875 as recommended here:
|
1147
|
-
# https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
|
1148
|
-
with multiprocessing.Pool(thread_count) as p:
|
1149
|
-
return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
|
1150
|
-
|
1151
|
-
def _kde_eval(d, grid):
|
1152
|
-
# Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
|
1153
|
-
# can not be compared well on regions further away from the data as they are -inf
|
1154
|
-
kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
|
1155
|
-
return _parallel_score_samples(kde, grid)
|
1156
|
-
|
1157
|
-
mean_x, var_x = _mean_var(X, log=True)
|
1158
|
-
mean_y, var_y = _mean_var(Y, log=True)
|
1159
|
-
|
1160
|
-
x = _prep_kde_data(mean_x, var_x)
|
1161
|
-
y = _prep_kde_data(mean_y, var_y)
|
1173
|
+
x = self._prep_kde_data(mean_x, var_x)
|
1174
|
+
y = self._prep_kde_data(mean_y, var_y)
|
1162
1175
|
|
1163
1176
|
# Gridpoints to eval KDE on
|
1164
|
-
mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
|
1165
|
-
var_grid = _grid_points(np.concatenate([var_x, var_y]))
|
1177
|
+
mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
|
1178
|
+
var_grid = self._grid_points(np.concatenate([var_x, var_y]))
|
1166
1179
|
grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
|
1167
1180
|
|
1168
|
-
|
1169
|
-
|
1181
|
+
# Fit both KDEs first
|
1182
|
+
x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
|
1183
|
+
y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
|
1170
1184
|
|
1171
|
-
|
1185
|
+
# Evaluate both KDEs on same grid chunks
|
1186
|
+
kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
|
1172
1187
|
|
1173
|
-
return
|
1188
|
+
return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
|
1174
1189
|
|
1175
1190
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
1176
1191
|
raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
|
pertpy/tools/_enrichment.py
CHANGED
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
|
3
3
|
from typing import Any, Literal
|
4
4
|
|
5
5
|
import blitzgsea
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
@@ -14,6 +15,7 @@ from scipy.sparse import issparse
|
|
14
15
|
from scipy.stats import hypergeom
|
15
16
|
from statsmodels.stats.multitest import multipletests
|
16
17
|
|
18
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
19
|
from pertpy.metadata import Drug
|
18
20
|
|
19
21
|
|
@@ -290,9 +292,11 @@ class Enrichment:
|
|
290
292
|
|
291
293
|
return enrichment
|
292
294
|
|
295
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
293
296
|
def plot_dotplot(
|
294
297
|
self,
|
295
298
|
adata: AnnData,
|
299
|
+
*,
|
296
300
|
targets: dict[str, dict[str, list[str]]] = None,
|
297
301
|
source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
|
298
302
|
category_name: str = "interaction_type",
|
@@ -300,10 +304,9 @@ class Enrichment:
|
|
300
304
|
groupby: str = None,
|
301
305
|
key: str = "pertpy_enrichment",
|
302
306
|
ax: Axes | None = None,
|
303
|
-
|
304
|
-
show: bool | None = None,
|
307
|
+
return_fig: bool = False,
|
305
308
|
**kwargs,
|
306
|
-
) -> DotPlot |
|
309
|
+
) -> DotPlot | None:
|
307
310
|
"""Plots a dotplot by groupby and categories.
|
308
311
|
|
309
312
|
Wraps scanpy's dotplot but formats it nicely by categories.
|
@@ -319,11 +322,11 @@ class Enrichment:
|
|
319
322
|
category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
|
320
323
|
groupby: dotplot groupby such as clusters or cell types.
|
321
324
|
key: Prefix key of enrichment results in `uns`.
|
325
|
+
{common_plot_args}
|
322
326
|
kwargs: Passed to scanpy dotplot.
|
323
327
|
|
324
328
|
Returns:
|
325
|
-
If `return_fig` is `True`, returns
|
326
|
-
else if `show` is false, return axes dict.
|
329
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
327
330
|
|
328
331
|
Examples:
|
329
332
|
>>> import pertpy as pt
|
@@ -403,21 +406,26 @@ class Enrichment:
|
|
403
406
|
"var_group_labels": var_group_labels,
|
404
407
|
}
|
405
408
|
|
406
|
-
|
409
|
+
fig = sc.pl.dotplot(
|
407
410
|
enrichment_score_adata,
|
408
411
|
groupby=groupby,
|
409
412
|
swap_axes=True,
|
410
413
|
ax=ax,
|
411
|
-
|
412
|
-
show=show,
|
414
|
+
show=False,
|
413
415
|
**plot_args,
|
414
416
|
**kwargs,
|
415
417
|
)
|
416
418
|
|
419
|
+
if return_fig:
|
420
|
+
return fig
|
421
|
+
plt.show()
|
422
|
+
return None
|
423
|
+
|
417
424
|
def plot_gsea(
|
418
425
|
self,
|
419
426
|
adata: AnnData,
|
420
427
|
enrichment: dict[str, pd.DataFrame],
|
428
|
+
*,
|
421
429
|
n: int = 10,
|
422
430
|
key: str = "pertpy_enrichment_gsea",
|
423
431
|
interactive_plot: bool = False,
|