pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
5
|
-
from
|
5
|
+
from lamin_utils import logger
|
6
6
|
from scipy.sparse import issparse
|
7
7
|
|
8
8
|
from ._base import LinearModelBase
|
@@ -27,16 +27,11 @@ class EdgeR(LinearModelBase):
|
|
27
27
|
# pandas2ri.activate()
|
28
28
|
# rpy2.robjects.numpy2ri.activate()
|
29
29
|
try:
|
30
|
-
import rpy2.robjects.numpy2ri
|
31
|
-
import rpy2.robjects.pandas2ri
|
32
30
|
from rpy2 import robjects as ro
|
33
31
|
from rpy2.robjects import numpy2ri, pandas2ri
|
34
|
-
from rpy2.robjects.conversion import localconverter
|
32
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
35
33
|
from rpy2.robjects.packages import importr
|
36
34
|
|
37
|
-
pandas2ri.activate()
|
38
|
-
rpy2.robjects.numpy2ri.activate()
|
39
|
-
|
40
35
|
except ImportError:
|
41
36
|
raise ImportError("edger requires rpy2 to be installed.") from None
|
42
37
|
|
@@ -49,25 +44,30 @@ class EdgeR(LinearModelBase):
|
|
49
44
|
) from e
|
50
45
|
|
51
46
|
# Convert dataframe
|
52
|
-
with localconverter(
|
47
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
53
48
|
expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
|
54
49
|
if issparse(expr):
|
55
50
|
expr = expr.T.toarray()
|
56
51
|
else:
|
57
52
|
expr = expr.T
|
58
53
|
|
59
|
-
|
54
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
55
|
+
expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
|
56
|
+
samples_r = ro.conversion.py2rpy(self.adata.obs)
|
60
57
|
|
61
|
-
dge = edger.DGEList(counts=expr_r, samples=
|
58
|
+
dge = edger.DGEList(counts=expr_r, samples=samples_r)
|
62
59
|
|
63
|
-
|
60
|
+
logger.info("Calculating NormFactors")
|
64
61
|
dge = edger.calcNormFactors(dge)
|
65
62
|
|
66
|
-
|
67
|
-
|
63
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
64
|
+
design_r = ro.conversion.py2rpy(self.design.values)
|
65
|
+
|
66
|
+
logger.info("Estimating Dispersions")
|
67
|
+
dge = edger.estimateDisp(dge, design=design_r)
|
68
68
|
|
69
|
-
|
70
|
-
fit = edger.glmQLFit(dge, design=
|
69
|
+
logger.info("Fitting linear model")
|
70
|
+
fit = edger.glmQLFit(dge, design=design_r, **kwargs)
|
71
71
|
|
72
72
|
ro.globalenv["fit"] = fit
|
73
73
|
self.fit = fit
|
@@ -88,11 +88,9 @@ class EdgeR(LinearModelBase):
|
|
88
88
|
# Fix mask for .fit()
|
89
89
|
|
90
90
|
try:
|
91
|
-
import rpy2.robjects.numpy2ri
|
92
|
-
import rpy2.robjects.pandas2ri
|
93
91
|
from rpy2 import robjects as ro
|
94
92
|
from rpy2.robjects import numpy2ri, pandas2ri
|
95
|
-
from rpy2.robjects.conversion import localconverter
|
93
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
96
94
|
from rpy2.robjects.packages import importr
|
97
95
|
|
98
96
|
except ImportError:
|
@@ -106,7 +104,8 @@ class EdgeR(LinearModelBase):
|
|
106
104
|
) from None
|
107
105
|
|
108
106
|
# Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
|
109
|
-
|
107
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
108
|
+
contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
|
110
109
|
ro.globalenv["contrast_vec"] = contrast_vec_r
|
111
110
|
|
112
111
|
# Test contrast with R
|
@@ -117,8 +116,18 @@ class EdgeR(LinearModelBase):
|
|
117
116
|
"""
|
118
117
|
)
|
119
118
|
|
120
|
-
#
|
121
|
-
de_res = ro.
|
119
|
+
# Retrieve the `de_res` object
|
120
|
+
de_res = ro.globalenv["de_res"]
|
121
|
+
|
122
|
+
# If already a Pandas DataFrame, return it directly
|
123
|
+
if isinstance(de_res, pd.DataFrame):
|
124
|
+
de_res.index.name = "variable"
|
125
|
+
return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
|
126
|
+
|
127
|
+
# Convert to Pandas DataFrame if still an R object
|
128
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
129
|
+
de_res = ro.conversion.rpy2py(de_res)
|
130
|
+
|
122
131
|
de_res.index.name = "variable"
|
123
132
|
de_res = de_res.reset_index()
|
124
133
|
|
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
import re
|
3
3
|
import warnings
|
4
4
|
|
5
|
+
import numpy as np
|
5
6
|
import pandas as pd
|
6
7
|
from anndata import AnnData
|
7
8
|
from numpy import ndarray
|
@@ -40,33 +41,25 @@ class PyDESeq2(LinearModelBase):
|
|
40
41
|
Args:
|
41
42
|
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
|
42
43
|
"""
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
44
|
+
try:
|
45
|
+
usable_cpus = len(os.sched_getaffinity(0))
|
46
|
+
except AttributeError:
|
47
|
+
usable_cpus = os.cpu_count()
|
48
|
+
|
49
|
+
inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
|
50
|
+
|
50
51
|
dds = DeseqDataSet(
|
51
|
-
adata=self.adata,
|
52
|
+
adata=self.adata,
|
53
|
+
design=self.design, # initialize using design matrix, not formula
|
54
|
+
refit_cooks=True,
|
55
|
+
inference=inference,
|
56
|
+
**kwargs,
|
52
57
|
)
|
53
|
-
# workaround code to insert design array
|
54
|
-
des_mtx_cols = dds.obsm["design_matrix"].columns
|
55
|
-
dds.obsm["design_matrix"] = self.design
|
56
|
-
if dds.obsm["design_matrix"].shape[1] == len(des_mtx_cols):
|
57
|
-
dds.obsm["design_matrix"].columns = des_mtx_cols.copy()
|
58
58
|
|
59
59
|
dds.deseq2()
|
60
60
|
self.dds = dds
|
61
61
|
|
62
|
-
|
63
|
-
# see https://github.com/owkin/PyDESeq2/issues/213
|
64
|
-
|
65
|
-
# Therefore these functions are overridden in a way to make it work with PyDESeq2,
|
66
|
-
# ingoring the inconsistency of function signatures. Once arbitrary design
|
67
|
-
# matrices and contrasts are supported by PyDEseq2, we can fully support the
|
68
|
-
# Linear model interface.
|
69
|
-
def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame: # type: ignore
|
62
|
+
def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
|
70
63
|
"""Conduct a specific test and returns a Pandas DataFrame.
|
71
64
|
|
72
65
|
Args:
|
@@ -74,6 +67,7 @@ class PyDESeq2(LinearModelBase):
|
|
74
67
|
alpha: p value threshold used for controlling fdr with independent hypothesis weighting
|
75
68
|
**kwargs: extra arguments to pass to DeseqStats()
|
76
69
|
"""
|
70
|
+
contrast = np.array(contrast)
|
77
71
|
stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
|
78
72
|
# Calling `.summary()` is required to fill the `results_df` data frame
|
79
73
|
stat_res.summary()
|
@@ -85,11 +79,3 @@ class PyDESeq2(LinearModelBase):
|
|
85
79
|
res_df.index.name = "variable"
|
86
80
|
res_df = res_df.reset_index()
|
87
81
|
return res_df
|
88
|
-
|
89
|
-
def cond(self, **kwargs) -> ndarray:
|
90
|
-
raise NotImplementedError(
|
91
|
-
"PyDESeq2 currently doesn't support arbitrary contrasts, see https://github.com/owkin/PyDESeq2/issues/213"
|
92
|
-
)
|
93
|
-
|
94
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> tuple[str, str, str]: # type: ignore
|
95
|
-
return (column, group_to_compare, baseline)
|
@@ -59,14 +59,3 @@ class Statsmodels(LinearModelBase):
|
|
59
59
|
}
|
60
60
|
)
|
61
61
|
return pd.DataFrame(res).sort_values("p_value")
|
62
|
-
|
63
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndarray:
|
64
|
-
"""Build a simple contrast for pairwise comparisons.
|
65
|
-
|
66
|
-
This is equivalent to
|
67
|
-
|
68
|
-
```
|
69
|
-
model.cond(<column> = baseline) - model.cond(<column> = group_to_compare)
|
70
|
-
```
|
71
|
-
"""
|
72
|
-
return self.cond(**{column: baseline}) - self.cond(**{column: group_to_compare})
|
@@ -344,9 +344,9 @@ class Distance:
|
|
344
344
|
else:
|
345
345
|
embedding = adata.obsm[self.obsm_key].copy()
|
346
346
|
for index_x, group_x in enumerate(fct(groups)):
|
347
|
-
cells_x = embedding[grouping == group_x].copy()
|
347
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
348
348
|
for group_y in groups[index_x:]: # type: ignore
|
349
|
-
cells_y = embedding[grouping == group_y].copy()
|
349
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
350
350
|
if not bootstrap:
|
351
351
|
# By distance axiom, the distance between a group and itself is 0
|
352
352
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -478,9 +478,9 @@ class Distance:
|
|
478
478
|
else:
|
479
479
|
embedding = adata.obsm[self.obsm_key].copy()
|
480
480
|
for group_x in fct(groups):
|
481
|
-
cells_x = embedding[grouping == group_x].copy()
|
481
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
482
482
|
group_y = selected_group
|
483
|
-
cells_y = embedding[grouping == group_y].copy()
|
483
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
484
484
|
if not bootstrap:
|
485
485
|
# By distance axiom, the distance between a group and itself is 0
|
486
486
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -691,17 +691,18 @@ class MMD(AbstractDistance):
|
|
691
691
|
|
692
692
|
|
693
693
|
class WassersteinDistance(AbstractDistance):
|
694
|
-
"""Wasserstein distance metric (solved with entropy regularized Sinkhorn)."""
|
695
|
-
|
696
694
|
def __init__(self) -> None:
|
697
695
|
super().__init__()
|
698
696
|
self.accepts_precomputed = False
|
699
697
|
|
700
698
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
699
|
+
X = np.asarray(X, dtype=np.float64)
|
700
|
+
Y = np.asarray(Y, dtype=np.float64)
|
701
701
|
geom = PointCloud(X, Y)
|
702
702
|
return self.solve_ot_problem(geom, **kwargs)
|
703
703
|
|
704
704
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
705
|
+
P = np.asarray(P, dtype=np.float64)
|
705
706
|
geom = Geometry(cost_matrix=P[idx, :][:, ~idx])
|
706
707
|
return self.solve_ot_problem(geom, **kwargs)
|
707
708
|
|
@@ -709,7 +710,13 @@ class WassersteinDistance(AbstractDistance):
|
|
709
710
|
ot_prob = LinearProblem(geom)
|
710
711
|
solver = Sinkhorn()
|
711
712
|
ot = solver(ot_prob, **kwargs)
|
712
|
-
|
713
|
+
cost = float(ot.reg_ot_cost)
|
714
|
+
|
715
|
+
# Check for NaN or invalid cost
|
716
|
+
if not np.isfinite(cost):
|
717
|
+
return 1.0
|
718
|
+
else:
|
719
|
+
return cost
|
713
720
|
|
714
721
|
|
715
722
|
class EuclideanDistance(AbstractDistance):
|
@@ -981,7 +988,7 @@ class NBLL(AbstractDistance):
|
|
981
988
|
try:
|
982
989
|
nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
|
983
990
|
return _compute_nll(y, nb_params, epsilon)
|
984
|
-
except np.linalg.
|
991
|
+
except np.linalg.LinAlgError:
|
985
992
|
if x.mean() < 10 and y.mean() < 10:
|
986
993
|
return 0.0
|
987
994
|
else:
|
pertpy/tools/_enrichment.py
CHANGED
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
|
3
3
|
from typing import Any, Literal
|
4
4
|
|
5
5
|
import blitzgsea
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
@@ -14,6 +15,7 @@ from scipy.sparse import issparse
|
|
14
15
|
from scipy.stats import hypergeom
|
15
16
|
from statsmodels.stats.multitest import multipletests
|
16
17
|
|
18
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
19
|
from pertpy.metadata import Drug
|
18
20
|
|
19
21
|
|
@@ -290,9 +292,11 @@ class Enrichment:
|
|
290
292
|
|
291
293
|
return enrichment
|
292
294
|
|
295
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
293
296
|
def plot_dotplot(
|
294
297
|
self,
|
295
298
|
adata: AnnData,
|
299
|
+
*,
|
296
300
|
targets: dict[str, dict[str, list[str]]] = None,
|
297
301
|
source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
|
298
302
|
category_name: str = "interaction_type",
|
@@ -300,10 +304,10 @@ class Enrichment:
|
|
300
304
|
groupby: str = None,
|
301
305
|
key: str = "pertpy_enrichment",
|
302
306
|
ax: Axes | None = None,
|
303
|
-
|
304
|
-
|
307
|
+
show: bool = True,
|
308
|
+
return_fig: bool = False,
|
305
309
|
**kwargs,
|
306
|
-
) -> DotPlot |
|
310
|
+
) -> DotPlot | None:
|
307
311
|
"""Plots a dotplot by groupby and categories.
|
308
312
|
|
309
313
|
Wraps scanpy's dotplot but formats it nicely by categories.
|
@@ -319,11 +323,11 @@ class Enrichment:
|
|
319
323
|
category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
|
320
324
|
groupby: dotplot groupby such as clusters or cell types.
|
321
325
|
key: Prefix key of enrichment results in `uns`.
|
326
|
+
{common_plot_args}
|
322
327
|
kwargs: Passed to scanpy dotplot.
|
323
328
|
|
324
329
|
Returns:
|
325
|
-
If `return_fig` is `True`, returns
|
326
|
-
else if `show` is false, return axes dict.
|
330
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
327
331
|
|
328
332
|
Examples:
|
329
333
|
>>> import pertpy as pt
|
@@ -403,21 +407,27 @@ class Enrichment:
|
|
403
407
|
"var_group_labels": var_group_labels,
|
404
408
|
}
|
405
409
|
|
406
|
-
|
410
|
+
fig = sc.pl.dotplot(
|
407
411
|
enrichment_score_adata,
|
408
412
|
groupby=groupby,
|
409
413
|
swap_axes=True,
|
410
414
|
ax=ax,
|
411
|
-
|
412
|
-
show=show,
|
415
|
+
show=False,
|
413
416
|
**plot_args,
|
414
417
|
**kwargs,
|
415
418
|
)
|
416
419
|
|
420
|
+
if show:
|
421
|
+
plt.show()
|
422
|
+
if return_fig:
|
423
|
+
return fig
|
424
|
+
return None
|
425
|
+
|
417
426
|
def plot_gsea(
|
418
427
|
self,
|
419
428
|
adata: AnnData,
|
420
429
|
enrichment: dict[str, pd.DataFrame],
|
430
|
+
*,
|
421
431
|
n: int = 10,
|
422
432
|
key: str = "pertpy_enrichment_gsea",
|
423
433
|
interactive_plot: bool = False,
|
pertpy/tools/_milo.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import logging
|
4
3
|
import random
|
5
4
|
import re
|
6
5
|
from typing import TYPE_CHECKING, Literal
|
@@ -14,6 +13,8 @@ from anndata import AnnData
|
|
14
13
|
from lamin_utils import logger
|
15
14
|
from mudata import MuData
|
16
15
|
|
16
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
|
+
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
from collections.abc import Sequence
|
19
20
|
|
@@ -125,7 +126,7 @@ class Milo:
|
|
125
126
|
try:
|
126
127
|
use_rep = adata.uns["neighbors"]["params"]["use_rep"]
|
127
128
|
except KeyError:
|
128
|
-
|
129
|
+
logger.warning("Using X_pca as default embedding")
|
129
130
|
use_rep = "X_pca"
|
130
131
|
try:
|
131
132
|
knn_graph = adata.obsp["connectivities"].copy()
|
@@ -136,7 +137,7 @@ class Milo:
|
|
136
137
|
try:
|
137
138
|
use_rep = adata.uns[neighbors_key]["params"]["use_rep"]
|
138
139
|
except KeyError:
|
139
|
-
|
140
|
+
logger.warning("Using X_pca as default embedding")
|
140
141
|
use_rep = "X_pca"
|
141
142
|
knn_graph = adata.obsp[neighbors_key + "_connectivities"].copy()
|
142
143
|
|
@@ -182,7 +183,7 @@ class Milo:
|
|
182
183
|
knn_dists = adata.obsp[neighbors_key + "_distances"]
|
183
184
|
|
184
185
|
nhood_ixs = adata.obs["nhood_ixs_refined"] == 1
|
185
|
-
dist_mat = knn_dists[nhood_ixs, :]
|
186
|
+
dist_mat = knn_dists[np.asarray(nhood_ixs), :]
|
186
187
|
k_distances = dist_mat.max(1).toarray().ravel()
|
187
188
|
adata.obs["nhood_kth_distance"] = 0
|
188
189
|
adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
|
@@ -703,8 +704,8 @@ class Milo:
|
|
703
704
|
pvalues = sample_adata.var["PValue"]
|
704
705
|
keep_nhoods = ~pvalues.isna() # Filtering in case of test on subset of nhoods
|
705
706
|
o = pvalues[keep_nhoods].argsort()
|
706
|
-
pvalues = pvalues[keep_nhoods][o]
|
707
|
-
w = w[keep_nhoods][o]
|
707
|
+
pvalues = pvalues.loc[keep_nhoods].iloc[o]
|
708
|
+
w = w.loc[keep_nhoods].iloc[o]
|
708
709
|
|
709
710
|
adjp = np.zeros(shape=len(o))
|
710
711
|
adjp[o] = (sum(w) * pvalues / np.cumsum(w))[::-1].cummin()[::-1]
|
@@ -713,9 +714,11 @@ class Milo:
|
|
713
714
|
sample_adata.var["SpatialFDR"] = np.nan
|
714
715
|
sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
|
715
716
|
|
717
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
716
718
|
def plot_nhood_graph(
|
717
719
|
self,
|
718
720
|
mdata: MuData,
|
721
|
+
*,
|
719
722
|
alpha: float = 0.1,
|
720
723
|
min_logFC: float = 0,
|
721
724
|
min_size: int = 10,
|
@@ -724,10 +727,10 @@ class Milo:
|
|
724
727
|
color_map: Colormap | str | None = None,
|
725
728
|
palette: str | Sequence[str] | None = None,
|
726
729
|
ax: Axes | None = None,
|
727
|
-
show: bool
|
728
|
-
|
730
|
+
show: bool = True,
|
731
|
+
return_fig: bool = False,
|
729
732
|
**kwargs,
|
730
|
-
) -> None:
|
733
|
+
) -> Figure | None:
|
731
734
|
"""Visualize DA results on abstracted graph (wrapper around sc.pl.embedding)
|
732
735
|
|
733
736
|
Args:
|
@@ -737,9 +740,7 @@ class Milo:
|
|
737
740
|
min_size: Minimum size of nodes in visualization. (default: 10)
|
738
741
|
plot_edges: If edges for neighbourhood overlaps whould be plotted.
|
739
742
|
title: Plot title.
|
740
|
-
|
741
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
742
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
743
|
+
{common_plot_args}
|
743
744
|
**kwargs: Additional arguments to `scanpy.pl.embedding`.
|
744
745
|
|
745
746
|
Examples:
|
@@ -782,7 +783,7 @@ class Milo:
|
|
782
783
|
vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
|
783
784
|
vmin = -vmax
|
784
785
|
|
785
|
-
sc.pl.embedding(
|
786
|
+
fig = sc.pl.embedding(
|
786
787
|
nhood_adata,
|
787
788
|
"X_milo_graph",
|
788
789
|
color="graph_color",
|
@@ -798,33 +799,42 @@ class Milo:
|
|
798
799
|
color_map=color_map,
|
799
800
|
palette=palette,
|
800
801
|
ax=ax,
|
801
|
-
show=
|
802
|
-
save=save,
|
802
|
+
show=False,
|
803
803
|
**kwargs,
|
804
804
|
)
|
805
805
|
|
806
|
+
if show:
|
807
|
+
plt.show()
|
808
|
+
if return_fig:
|
809
|
+
return fig
|
810
|
+
return None
|
811
|
+
|
812
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
806
813
|
def plot_nhood(
|
807
814
|
self,
|
808
815
|
mdata: MuData,
|
809
816
|
ix: int,
|
817
|
+
*,
|
810
818
|
feature_key: str | None = "rna",
|
811
819
|
basis: str = "X_umap",
|
812
820
|
color_map: Colormap | str | None = None,
|
813
821
|
palette: str | Sequence[str] | None = None,
|
814
|
-
return_fig: bool | None = None,
|
815
822
|
ax: Axes | None = None,
|
816
|
-
show: bool
|
817
|
-
|
823
|
+
show: bool = True,
|
824
|
+
return_fig: bool = False,
|
818
825
|
**kwargs,
|
819
|
-
) -> None:
|
826
|
+
) -> Figure | None:
|
820
827
|
"""Visualize cells in a neighbourhood.
|
821
828
|
|
822
829
|
Args:
|
823
830
|
mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
|
824
831
|
ix: index of neighbourhood to visualize
|
832
|
+
feature_key: Key in mdata to the cell-level AnnData object.
|
825
833
|
basis: Embedding to use for visualization.
|
826
|
-
|
827
|
-
|
834
|
+
color_map: Colormap to use for coloring.
|
835
|
+
palette: Color palette to use for coloring.
|
836
|
+
ax: Axes to plot on.
|
837
|
+
{common_plot_args}
|
828
838
|
**kwargs: Additional arguments to `scanpy.pl.embedding`.
|
829
839
|
|
830
840
|
Examples:
|
@@ -842,7 +852,7 @@ class Milo:
|
|
842
852
|
.. image:: /_static/docstring_previews/milo_nhood.png
|
843
853
|
"""
|
844
854
|
mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
|
845
|
-
sc.pl.embedding(
|
855
|
+
fig = sc.pl.embedding(
|
846
856
|
mdata[feature_key],
|
847
857
|
basis,
|
848
858
|
color="Nhood",
|
@@ -852,32 +862,43 @@ class Milo:
|
|
852
862
|
palette=palette,
|
853
863
|
return_fig=return_fig,
|
854
864
|
ax=ax,
|
855
|
-
show=
|
856
|
-
save=save,
|
865
|
+
show=False,
|
857
866
|
**kwargs,
|
858
867
|
)
|
859
868
|
|
869
|
+
if show:
|
870
|
+
plt.show()
|
871
|
+
if return_fig:
|
872
|
+
return fig
|
873
|
+
return None
|
874
|
+
|
875
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
860
876
|
def plot_da_beeswarm(
|
861
877
|
self,
|
862
878
|
mdata: MuData,
|
879
|
+
*,
|
863
880
|
feature_key: str | None = "rna",
|
864
881
|
anno_col: str = "nhood_annotation",
|
865
882
|
alpha: float = 0.1,
|
866
883
|
subset_nhoods: list[str] = None,
|
867
884
|
palette: str | Sequence[str] | dict[str, str] | None = None,
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
) -> Figure | Axes | None:
|
885
|
+
show: bool = True,
|
886
|
+
return_fig: bool = False,
|
887
|
+
) -> Figure | None:
|
872
888
|
"""Plot beeswarm plot of logFC against nhood labels
|
873
889
|
|
874
890
|
Args:
|
875
891
|
mdata: MuData object
|
892
|
+
feature_key: Key in mdata to the cell-level AnnData object.
|
876
893
|
anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
|
877
894
|
alpha: Significance threshold. (default: 0.1)
|
878
895
|
subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
|
879
896
|
palette: Name of Seaborn color palette for violinplots.
|
880
897
|
Defaults to pre-defined category colors for violinplots.
|
898
|
+
{common_plot_args}
|
899
|
+
|
900
|
+
Returns:
|
901
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
881
902
|
|
882
903
|
Examples:
|
883
904
|
>>> import pertpy as pt
|
@@ -973,29 +994,23 @@ class Milo:
|
|
973
994
|
plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
|
974
995
|
plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
|
975
996
|
|
976
|
-
if save:
|
977
|
-
plt.savefig(save, bbox_inches="tight")
|
978
|
-
return None
|
979
997
|
if show:
|
980
998
|
plt.show()
|
981
|
-
return None
|
982
999
|
if return_fig:
|
983
1000
|
return plt.gcf()
|
984
|
-
if (not show and not save) or (show is None and save is None):
|
985
|
-
return plt.gca()
|
986
|
-
|
987
1001
|
return None
|
988
1002
|
|
1003
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
989
1004
|
def plot_nhood_counts_by_cond(
|
990
1005
|
self,
|
991
1006
|
mdata: MuData,
|
992
1007
|
test_var: str,
|
1008
|
+
*,
|
993
1009
|
subset_nhoods: list[str] = None,
|
994
1010
|
log_counts: bool = False,
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
) -> Figure | Axes | None:
|
1011
|
+
show: bool = True,
|
1012
|
+
return_fig: bool = False,
|
1013
|
+
) -> Figure | None:
|
999
1014
|
"""Plot boxplot of cell numbers vs condition of interest.
|
1000
1015
|
|
1001
1016
|
Args:
|
@@ -1003,6 +1018,10 @@ class Milo:
|
|
1003
1018
|
test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
|
1004
1019
|
subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
|
1005
1020
|
log_counts: Whether to plot log1p of cell counts.
|
1021
|
+
{common_plot_args}
|
1022
|
+
|
1023
|
+
Returns:
|
1024
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1006
1025
|
"""
|
1007
1026
|
try:
|
1008
1027
|
nhood_adata = mdata["milo"].T.copy()
|
@@ -1031,15 +1050,8 @@ class Milo:
|
|
1031
1050
|
plt.xticks(rotation=90)
|
1032
1051
|
plt.xlabel(test_var)
|
1033
1052
|
|
1034
|
-
if save:
|
1035
|
-
plt.savefig(save, bbox_inches="tight")
|
1036
|
-
return None
|
1037
1053
|
if show:
|
1038
1054
|
plt.show()
|
1039
|
-
return None
|
1040
1055
|
if return_fig:
|
1041
1056
|
return plt.gcf()
|
1042
|
-
if not (show or save):
|
1043
|
-
return plt.gca()
|
1044
|
-
|
1045
1057
|
return None
|