pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
5
|
-
from
|
5
|
+
from lamin_utils import logger
|
6
6
|
from scipy.sparse import issparse
|
7
7
|
|
8
8
|
from ._base import LinearModelBase
|
@@ -27,16 +27,11 @@ class EdgeR(LinearModelBase):
|
|
27
27
|
# pandas2ri.activate()
|
28
28
|
# rpy2.robjects.numpy2ri.activate()
|
29
29
|
try:
|
30
|
-
import rpy2.robjects.numpy2ri
|
31
|
-
import rpy2.robjects.pandas2ri
|
32
30
|
from rpy2 import robjects as ro
|
33
31
|
from rpy2.robjects import numpy2ri, pandas2ri
|
34
|
-
from rpy2.robjects.conversion import localconverter
|
32
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
35
33
|
from rpy2.robjects.packages import importr
|
36
34
|
|
37
|
-
pandas2ri.activate()
|
38
|
-
rpy2.robjects.numpy2ri.activate()
|
39
|
-
|
40
35
|
except ImportError:
|
41
36
|
raise ImportError("edger requires rpy2 to be installed.") from None
|
42
37
|
|
@@ -49,25 +44,30 @@ class EdgeR(LinearModelBase):
|
|
49
44
|
) from e
|
50
45
|
|
51
46
|
# Convert dataframe
|
52
|
-
with localconverter(
|
47
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
53
48
|
expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
|
54
49
|
if issparse(expr):
|
55
50
|
expr = expr.T.toarray()
|
56
51
|
else:
|
57
52
|
expr = expr.T
|
58
53
|
|
59
|
-
|
54
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
55
|
+
expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
|
56
|
+
samples_r = ro.conversion.py2rpy(self.adata.obs)
|
60
57
|
|
61
|
-
dge = edger.DGEList(counts=expr_r, samples=
|
58
|
+
dge = edger.DGEList(counts=expr_r, samples=samples_r)
|
62
59
|
|
63
|
-
|
60
|
+
logger.info("Calculating NormFactors")
|
64
61
|
dge = edger.calcNormFactors(dge)
|
65
62
|
|
66
|
-
|
67
|
-
|
63
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
64
|
+
design_r = ro.conversion.py2rpy(self.design.values)
|
65
|
+
|
66
|
+
logger.info("Estimating Dispersions")
|
67
|
+
dge = edger.estimateDisp(dge, design=design_r)
|
68
68
|
|
69
|
-
|
70
|
-
fit = edger.glmQLFit(dge, design=
|
69
|
+
logger.info("Fitting linear model")
|
70
|
+
fit = edger.glmQLFit(dge, design=design_r, **kwargs)
|
71
71
|
|
72
72
|
ro.globalenv["fit"] = fit
|
73
73
|
self.fit = fit
|
@@ -88,11 +88,9 @@ class EdgeR(LinearModelBase):
|
|
88
88
|
# Fix mask for .fit()
|
89
89
|
|
90
90
|
try:
|
91
|
-
import rpy2.robjects.numpy2ri
|
92
|
-
import rpy2.robjects.pandas2ri
|
93
91
|
from rpy2 import robjects as ro
|
94
92
|
from rpy2.robjects import numpy2ri, pandas2ri
|
95
|
-
from rpy2.robjects.conversion import localconverter
|
93
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
96
94
|
from rpy2.robjects.packages import importr
|
97
95
|
|
98
96
|
except ImportError:
|
@@ -106,7 +104,8 @@ class EdgeR(LinearModelBase):
|
|
106
104
|
) from None
|
107
105
|
|
108
106
|
# Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
|
109
|
-
|
107
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
108
|
+
contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
|
110
109
|
ro.globalenv["contrast_vec"] = contrast_vec_r
|
111
110
|
|
112
111
|
# Test contrast with R
|
@@ -117,8 +116,18 @@ class EdgeR(LinearModelBase):
|
|
117
116
|
"""
|
118
117
|
)
|
119
118
|
|
120
|
-
#
|
121
|
-
de_res = ro.
|
119
|
+
# Retrieve the `de_res` object
|
120
|
+
de_res = ro.globalenv["de_res"]
|
121
|
+
|
122
|
+
# If already a Pandas DataFrame, return it directly
|
123
|
+
if isinstance(de_res, pd.DataFrame):
|
124
|
+
de_res.index.name = "variable"
|
125
|
+
return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
|
126
|
+
|
127
|
+
# Convert to Pandas DataFrame if still an R object
|
128
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
129
|
+
de_res = ro.conversion.rpy2py(de_res)
|
130
|
+
|
122
131
|
de_res.index.name = "variable"
|
123
132
|
de_res = de_res.reset_index()
|
124
133
|
|
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
import re
|
3
3
|
import warnings
|
4
4
|
|
5
|
+
import numpy as np
|
5
6
|
import pandas as pd
|
6
7
|
from anndata import AnnData
|
7
8
|
from numpy import ndarray
|
@@ -40,33 +41,25 @@ class PyDESeq2(LinearModelBase):
|
|
40
41
|
Args:
|
41
42
|
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
|
42
43
|
"""
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
44
|
+
try:
|
45
|
+
usable_cpus = len(os.sched_getaffinity(0))
|
46
|
+
except AttributeError:
|
47
|
+
usable_cpus = os.cpu_count()
|
48
|
+
|
49
|
+
inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
|
50
|
+
|
50
51
|
dds = DeseqDataSet(
|
51
|
-
adata=self.adata,
|
52
|
+
adata=self.adata,
|
53
|
+
design=self.design, # initialize using design matrix, not formula
|
54
|
+
refit_cooks=True,
|
55
|
+
inference=inference,
|
56
|
+
**kwargs,
|
52
57
|
)
|
53
|
-
# workaround code to insert design array
|
54
|
-
des_mtx_cols = dds.obsm["design_matrix"].columns
|
55
|
-
dds.obsm["design_matrix"] = self.design
|
56
|
-
if dds.obsm["design_matrix"].shape[1] == len(des_mtx_cols):
|
57
|
-
dds.obsm["design_matrix"].columns = des_mtx_cols.copy()
|
58
58
|
|
59
59
|
dds.deseq2()
|
60
60
|
self.dds = dds
|
61
61
|
|
62
|
-
|
63
|
-
# see https://github.com/owkin/PyDESeq2/issues/213
|
64
|
-
|
65
|
-
# Therefore these functions are overridden in a way to make it work with PyDESeq2,
|
66
|
-
# ingoring the inconsistency of function signatures. Once arbitrary design
|
67
|
-
# matrices and contrasts are supported by PyDEseq2, we can fully support the
|
68
|
-
# Linear model interface.
|
69
|
-
def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame: # type: ignore
|
62
|
+
def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
|
70
63
|
"""Conduct a specific test and returns a Pandas DataFrame.
|
71
64
|
|
72
65
|
Args:
|
@@ -74,6 +67,7 @@ class PyDESeq2(LinearModelBase):
|
|
74
67
|
alpha: p value threshold used for controlling fdr with independent hypothesis weighting
|
75
68
|
**kwargs: extra arguments to pass to DeseqStats()
|
76
69
|
"""
|
70
|
+
contrast = np.array(contrast)
|
77
71
|
stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
|
78
72
|
# Calling `.summary()` is required to fill the `results_df` data frame
|
79
73
|
stat_res.summary()
|
@@ -85,11 +79,3 @@ class PyDESeq2(LinearModelBase):
|
|
85
79
|
res_df.index.name = "variable"
|
86
80
|
res_df = res_df.reset_index()
|
87
81
|
return res_df
|
88
|
-
|
89
|
-
def cond(self, **kwargs) -> ndarray:
|
90
|
-
raise NotImplementedError(
|
91
|
-
"PyDESeq2 currently doesn't support arbitrary contrasts, see https://github.com/owkin/PyDESeq2/issues/213"
|
92
|
-
)
|
93
|
-
|
94
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> tuple[str, str, str]: # type: ignore
|
95
|
-
return (column, group_to_compare, baseline)
|
@@ -59,14 +59,3 @@ class Statsmodels(LinearModelBase):
|
|
59
59
|
}
|
60
60
|
)
|
61
61
|
return pd.DataFrame(res).sort_values("p_value")
|
62
|
-
|
63
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndarray:
|
64
|
-
"""Build a simple contrast for pairwise comparisons.
|
65
|
-
|
66
|
-
This is equivalent to
|
67
|
-
|
68
|
-
```
|
69
|
-
model.cond(<column> = baseline) - model.cond(<column> = group_to_compare)
|
70
|
-
```
|
71
|
-
"""
|
72
|
-
return self.cond(**{column: baseline}) - self.cond(**{column: group_to_compare})
|
@@ -344,9 +344,9 @@ class Distance:
|
|
344
344
|
else:
|
345
345
|
embedding = adata.obsm[self.obsm_key].copy()
|
346
346
|
for index_x, group_x in enumerate(fct(groups)):
|
347
|
-
cells_x = embedding[grouping == group_x].copy()
|
347
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
348
348
|
for group_y in groups[index_x:]: # type: ignore
|
349
|
-
cells_y = embedding[grouping == group_y].copy()
|
349
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
350
350
|
if not bootstrap:
|
351
351
|
# By distance axiom, the distance between a group and itself is 0
|
352
352
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -478,9 +478,9 @@ class Distance:
|
|
478
478
|
else:
|
479
479
|
embedding = adata.obsm[self.obsm_key].copy()
|
480
480
|
for group_x in fct(groups):
|
481
|
-
cells_x = embedding[grouping == group_x].copy()
|
481
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
482
482
|
group_y = selected_group
|
483
|
-
cells_y = embedding[grouping == group_y].copy()
|
483
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
484
484
|
if not bootstrap:
|
485
485
|
# By distance axiom, the distance between a group and itself is 0
|
486
486
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -691,17 +691,18 @@ class MMD(AbstractDistance):
|
|
691
691
|
|
692
692
|
|
693
693
|
class WassersteinDistance(AbstractDistance):
|
694
|
-
"""Wasserstein distance metric (solved with entropy regularized Sinkhorn)."""
|
695
|
-
|
696
694
|
def __init__(self) -> None:
|
697
695
|
super().__init__()
|
698
696
|
self.accepts_precomputed = False
|
699
697
|
|
700
698
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
699
|
+
X = np.asarray(X, dtype=np.float64)
|
700
|
+
Y = np.asarray(Y, dtype=np.float64)
|
701
701
|
geom = PointCloud(X, Y)
|
702
702
|
return self.solve_ot_problem(geom, **kwargs)
|
703
703
|
|
704
704
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
705
|
+
P = np.asarray(P, dtype=np.float64)
|
705
706
|
geom = Geometry(cost_matrix=P[idx, :][:, ~idx])
|
706
707
|
return self.solve_ot_problem(geom, **kwargs)
|
707
708
|
|
@@ -709,7 +710,13 @@ class WassersteinDistance(AbstractDistance):
|
|
709
710
|
ot_prob = LinearProblem(geom)
|
710
711
|
solver = Sinkhorn()
|
711
712
|
ot = solver(ot_prob, **kwargs)
|
712
|
-
|
713
|
+
cost = float(ot.reg_ot_cost)
|
714
|
+
|
715
|
+
# Check for NaN or invalid cost
|
716
|
+
if not np.isfinite(cost):
|
717
|
+
return 1.0
|
718
|
+
else:
|
719
|
+
return cost
|
713
720
|
|
714
721
|
|
715
722
|
class EuclideanDistance(AbstractDistance):
|
@@ -981,7 +988,7 @@ class NBLL(AbstractDistance):
|
|
981
988
|
try:
|
982
989
|
nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
|
983
990
|
return _compute_nll(y, nb_params, epsilon)
|
984
|
-
except np.linalg.
|
991
|
+
except np.linalg.LinAlgError:
|
985
992
|
if x.mean() < 10 and y.mean() < 10:
|
986
993
|
return 0.0
|
987
994
|
else:
|
@@ -1110,67 +1117,75 @@ class MeanVarDistributionDistance(AbstractDistance):
|
|
1110
1117
|
super().__init__()
|
1111
1118
|
self.accepts_precomputed = False
|
1112
1119
|
|
1120
|
+
@staticmethod
|
1121
|
+
def _mean_var(x, log: bool = False):
|
1122
|
+
mean = np.mean(x, axis=0)
|
1123
|
+
var = np.var(x, axis=0)
|
1124
|
+
positive = mean > 0
|
1125
|
+
mean = mean[positive]
|
1126
|
+
var = var[positive]
|
1127
|
+
if log:
|
1128
|
+
mean = np.log(mean)
|
1129
|
+
var = np.log(var)
|
1130
|
+
return mean, var
|
1131
|
+
|
1132
|
+
@staticmethod
|
1133
|
+
def _prep_kde_data(x, y):
|
1134
|
+
return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
|
1135
|
+
|
1136
|
+
@staticmethod
|
1137
|
+
def _grid_points(d, n_points=100):
|
1138
|
+
# Make grid, add 1 bin on lower/upper end to get final n_points
|
1139
|
+
d_min = d.min()
|
1140
|
+
d_max = d.max()
|
1141
|
+
# Compute bin size
|
1142
|
+
d_bin = (d_max - d_min) / (n_points - 2)
|
1143
|
+
d_min = d_min - d_bin
|
1144
|
+
d_max = d_max + d_bin
|
1145
|
+
return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
|
1146
|
+
|
1147
|
+
@staticmethod
|
1148
|
+
def _kde_eval_both(x_kde, y_kde, grid):
|
1149
|
+
n_points = len(grid)
|
1150
|
+
chunk_size = 10000
|
1151
|
+
|
1152
|
+
result_x = np.zeros(n_points)
|
1153
|
+
result_y = np.zeros(n_points)
|
1154
|
+
|
1155
|
+
# Process same chunks for both KDEs
|
1156
|
+
for start in range(0, n_points, chunk_size):
|
1157
|
+
end = min(start + chunk_size, n_points)
|
1158
|
+
chunk = grid[start:end]
|
1159
|
+
result_x[start:end] = x_kde.score_samples(chunk)
|
1160
|
+
result_y[start:end] = y_kde.score_samples(chunk)
|
1161
|
+
|
1162
|
+
return result_x, result_y
|
1163
|
+
|
1113
1164
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
1114
1165
|
"""Difference of mean-var distributions in 2 matrices.
|
1115
|
-
|
1116
1166
|
Args:
|
1117
1167
|
X: Normalized and log transformed cells x genes count matrix.
|
1118
1168
|
Y: Normalized and log transformed cells x genes count matrix.
|
1119
1169
|
"""
|
1170
|
+
mean_x, var_x = self._mean_var(X, log=True)
|
1171
|
+
mean_y, var_y = self._mean_var(Y, log=True)
|
1120
1172
|
|
1121
|
-
|
1122
|
-
|
1123
|
-
var = np.var(x, axis=0)
|
1124
|
-
positive = mean > 0
|
1125
|
-
mean = mean[positive]
|
1126
|
-
var = var[positive]
|
1127
|
-
if log:
|
1128
|
-
mean = np.log(mean)
|
1129
|
-
var = np.log(var)
|
1130
|
-
return mean, var
|
1131
|
-
|
1132
|
-
def _prep_kde_data(x, y):
|
1133
|
-
return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
|
1134
|
-
|
1135
|
-
def _grid_points(d, n_points=100):
|
1136
|
-
# Make grid, add 1 bin on lower/upper end to get final n_points
|
1137
|
-
d_min = d.min()
|
1138
|
-
d_max = d.max()
|
1139
|
-
# Compute bin size
|
1140
|
-
d_bin = (d_max - d_min) / (n_points - 2)
|
1141
|
-
d_min = d_min - d_bin
|
1142
|
-
d_max = d_max + d_bin
|
1143
|
-
return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
|
1144
|
-
|
1145
|
-
def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
|
1146
|
-
# the thread_count is determined using the factor 0.875 as recommended here:
|
1147
|
-
# https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
|
1148
|
-
with multiprocessing.Pool(thread_count) as p:
|
1149
|
-
return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
|
1150
|
-
|
1151
|
-
def _kde_eval(d, grid):
|
1152
|
-
# Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
|
1153
|
-
# can not be compared well on regions further away from the data as they are -inf
|
1154
|
-
kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
|
1155
|
-
return _parallel_score_samples(kde, grid)
|
1156
|
-
|
1157
|
-
mean_x, var_x = _mean_var(X, log=True)
|
1158
|
-
mean_y, var_y = _mean_var(Y, log=True)
|
1159
|
-
|
1160
|
-
x = _prep_kde_data(mean_x, var_x)
|
1161
|
-
y = _prep_kde_data(mean_y, var_y)
|
1173
|
+
x = self._prep_kde_data(mean_x, var_x)
|
1174
|
+
y = self._prep_kde_data(mean_y, var_y)
|
1162
1175
|
|
1163
1176
|
# Gridpoints to eval KDE on
|
1164
|
-
mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
|
1165
|
-
var_grid = _grid_points(np.concatenate([var_x, var_y]))
|
1177
|
+
mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
|
1178
|
+
var_grid = self._grid_points(np.concatenate([var_x, var_y]))
|
1166
1179
|
grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
|
1167
1180
|
|
1168
|
-
|
1169
|
-
|
1181
|
+
# Fit both KDEs first
|
1182
|
+
x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
|
1183
|
+
y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
|
1170
1184
|
|
1171
|
-
|
1185
|
+
# Evaluate both KDEs on same grid chunks
|
1186
|
+
kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
|
1172
1187
|
|
1173
|
-
return
|
1188
|
+
return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
|
1174
1189
|
|
1175
1190
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
1176
1191
|
raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
|
pertpy/tools/_enrichment.py
CHANGED
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
|
3
3
|
from typing import Any, Literal
|
4
4
|
|
5
5
|
import blitzgsea
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
@@ -14,6 +15,7 @@ from scipy.sparse import issparse
|
|
14
15
|
from scipy.stats import hypergeom
|
15
16
|
from statsmodels.stats.multitest import multipletests
|
16
17
|
|
18
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
19
|
from pertpy.metadata import Drug
|
18
20
|
|
19
21
|
|
@@ -290,9 +292,11 @@ class Enrichment:
|
|
290
292
|
|
291
293
|
return enrichment
|
292
294
|
|
295
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
293
296
|
def plot_dotplot(
|
294
297
|
self,
|
295
298
|
adata: AnnData,
|
299
|
+
*,
|
296
300
|
targets: dict[str, dict[str, list[str]]] = None,
|
297
301
|
source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
|
298
302
|
category_name: str = "interaction_type",
|
@@ -300,10 +304,9 @@ class Enrichment:
|
|
300
304
|
groupby: str = None,
|
301
305
|
key: str = "pertpy_enrichment",
|
302
306
|
ax: Axes | None = None,
|
303
|
-
|
304
|
-
show: bool | None = None,
|
307
|
+
return_fig: bool = False,
|
305
308
|
**kwargs,
|
306
|
-
) -> DotPlot |
|
309
|
+
) -> DotPlot | None:
|
307
310
|
"""Plots a dotplot by groupby and categories.
|
308
311
|
|
309
312
|
Wraps scanpy's dotplot but formats it nicely by categories.
|
@@ -319,11 +322,11 @@ class Enrichment:
|
|
319
322
|
category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
|
320
323
|
groupby: dotplot groupby such as clusters or cell types.
|
321
324
|
key: Prefix key of enrichment results in `uns`.
|
325
|
+
{common_plot_args}
|
322
326
|
kwargs: Passed to scanpy dotplot.
|
323
327
|
|
324
328
|
Returns:
|
325
|
-
If `return_fig` is `True`, returns
|
326
|
-
else if `show` is false, return axes dict.
|
329
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
327
330
|
|
328
331
|
Examples:
|
329
332
|
>>> import pertpy as pt
|
@@ -403,21 +406,26 @@ class Enrichment:
|
|
403
406
|
"var_group_labels": var_group_labels,
|
404
407
|
}
|
405
408
|
|
406
|
-
|
409
|
+
fig = sc.pl.dotplot(
|
407
410
|
enrichment_score_adata,
|
408
411
|
groupby=groupby,
|
409
412
|
swap_axes=True,
|
410
413
|
ax=ax,
|
411
|
-
|
412
|
-
show=show,
|
414
|
+
show=False,
|
413
415
|
**plot_args,
|
414
416
|
**kwargs,
|
415
417
|
)
|
416
418
|
|
419
|
+
if return_fig:
|
420
|
+
return fig
|
421
|
+
plt.show()
|
422
|
+
return None
|
423
|
+
|
417
424
|
def plot_gsea(
|
418
425
|
self,
|
419
426
|
adata: AnnData,
|
420
427
|
enrichment: dict[str, pd.DataFrame],
|
428
|
+
*,
|
421
429
|
n: int = 10,
|
422
430
|
key: str = "pertpy_enrichment_gsea",
|
423
431
|
interactive_plot: bool = False,
|