pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
5
|
-
from
|
5
|
+
from lamin_utils import logger
|
6
6
|
from scipy.sparse import issparse
|
7
7
|
|
8
8
|
from ._base import LinearModelBase
|
@@ -27,16 +27,11 @@ class EdgeR(LinearModelBase):
|
|
27
27
|
# pandas2ri.activate()
|
28
28
|
# rpy2.robjects.numpy2ri.activate()
|
29
29
|
try:
|
30
|
-
import rpy2.robjects.numpy2ri
|
31
|
-
import rpy2.robjects.pandas2ri
|
32
30
|
from rpy2 import robjects as ro
|
33
31
|
from rpy2.robjects import numpy2ri, pandas2ri
|
34
|
-
from rpy2.robjects.conversion import localconverter
|
32
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
35
33
|
from rpy2.robjects.packages import importr
|
36
34
|
|
37
|
-
pandas2ri.activate()
|
38
|
-
rpy2.robjects.numpy2ri.activate()
|
39
|
-
|
40
35
|
except ImportError:
|
41
36
|
raise ImportError("edger requires rpy2 to be installed.") from None
|
42
37
|
|
@@ -49,25 +44,30 @@ class EdgeR(LinearModelBase):
|
|
49
44
|
) from e
|
50
45
|
|
51
46
|
# Convert dataframe
|
52
|
-
with localconverter(
|
47
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
53
48
|
expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
|
54
49
|
if issparse(expr):
|
55
50
|
expr = expr.T.toarray()
|
56
51
|
else:
|
57
52
|
expr = expr.T
|
58
53
|
|
59
|
-
|
54
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
55
|
+
expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
|
56
|
+
samples_r = ro.conversion.py2rpy(self.adata.obs)
|
60
57
|
|
61
|
-
dge = edger.DGEList(counts=expr_r, samples=
|
58
|
+
dge = edger.DGEList(counts=expr_r, samples=samples_r)
|
62
59
|
|
63
|
-
|
60
|
+
logger.info("Calculating NormFactors")
|
64
61
|
dge = edger.calcNormFactors(dge)
|
65
62
|
|
66
|
-
|
67
|
-
|
63
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
64
|
+
design_r = ro.conversion.py2rpy(self.design.values)
|
65
|
+
|
66
|
+
logger.info("Estimating Dispersions")
|
67
|
+
dge = edger.estimateDisp(dge, design=design_r)
|
68
68
|
|
69
|
-
|
70
|
-
fit = edger.glmQLFit(dge, design=
|
69
|
+
logger.info("Fitting linear model")
|
70
|
+
fit = edger.glmQLFit(dge, design=design_r, **kwargs)
|
71
71
|
|
72
72
|
ro.globalenv["fit"] = fit
|
73
73
|
self.fit = fit
|
@@ -88,11 +88,9 @@ class EdgeR(LinearModelBase):
|
|
88
88
|
# Fix mask for .fit()
|
89
89
|
|
90
90
|
try:
|
91
|
-
import rpy2.robjects.numpy2ri
|
92
|
-
import rpy2.robjects.pandas2ri
|
93
91
|
from rpy2 import robjects as ro
|
94
92
|
from rpy2.robjects import numpy2ri, pandas2ri
|
95
|
-
from rpy2.robjects.conversion import localconverter
|
93
|
+
from rpy2.robjects.conversion import get_conversion, localconverter
|
96
94
|
from rpy2.robjects.packages import importr
|
97
95
|
|
98
96
|
except ImportError:
|
@@ -106,7 +104,8 @@ class EdgeR(LinearModelBase):
|
|
106
104
|
) from None
|
107
105
|
|
108
106
|
# Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
|
109
|
-
|
107
|
+
with localconverter(get_conversion() + numpy2ri.converter):
|
108
|
+
contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
|
110
109
|
ro.globalenv["contrast_vec"] = contrast_vec_r
|
111
110
|
|
112
111
|
# Test contrast with R
|
@@ -117,8 +116,18 @@ class EdgeR(LinearModelBase):
|
|
117
116
|
"""
|
118
117
|
)
|
119
118
|
|
120
|
-
#
|
121
|
-
de_res = ro.
|
119
|
+
# Retrieve the `de_res` object
|
120
|
+
de_res = ro.globalenv["de_res"]
|
121
|
+
|
122
|
+
# If already a Pandas DataFrame, return it directly
|
123
|
+
if isinstance(de_res, pd.DataFrame):
|
124
|
+
de_res.index.name = "variable"
|
125
|
+
return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
|
126
|
+
|
127
|
+
# Convert to Pandas DataFrame if still an R object
|
128
|
+
with localconverter(get_conversion() + pandas2ri.converter):
|
129
|
+
de_res = ro.conversion.rpy2py(de_res)
|
130
|
+
|
122
131
|
de_res.index.name = "variable"
|
123
132
|
de_res = de_res.reset_index()
|
124
133
|
|
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
import re
|
3
3
|
import warnings
|
4
4
|
|
5
|
+
import numpy as np
|
5
6
|
import pandas as pd
|
6
7
|
from anndata import AnnData
|
7
8
|
from numpy import ndarray
|
@@ -40,33 +41,25 @@ class PyDESeq2(LinearModelBase):
|
|
40
41
|
Args:
|
41
42
|
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
|
42
43
|
"""
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
44
|
+
try:
|
45
|
+
usable_cpus = len(os.sched_getaffinity(0))
|
46
|
+
except AttributeError:
|
47
|
+
usable_cpus = os.cpu_count()
|
48
|
+
|
49
|
+
inference = DefaultInference(n_cpus=kwargs.pop("n_cpus", usable_cpus))
|
50
|
+
|
50
51
|
dds = DeseqDataSet(
|
51
|
-
adata=self.adata,
|
52
|
+
adata=self.adata,
|
53
|
+
design=self.design, # initialize using design matrix, not formula
|
54
|
+
refit_cooks=True,
|
55
|
+
inference=inference,
|
56
|
+
**kwargs,
|
52
57
|
)
|
53
|
-
# workaround code to insert design array
|
54
|
-
des_mtx_cols = dds.obsm["design_matrix"].columns
|
55
|
-
dds.obsm["design_matrix"] = self.design
|
56
|
-
if dds.obsm["design_matrix"].shape[1] == len(des_mtx_cols):
|
57
|
-
dds.obsm["design_matrix"].columns = des_mtx_cols.copy()
|
58
58
|
|
59
59
|
dds.deseq2()
|
60
60
|
self.dds = dds
|
61
61
|
|
62
|
-
|
63
|
-
# see https://github.com/owkin/PyDESeq2/issues/213
|
64
|
-
|
65
|
-
# Therefore these functions are overridden in a way to make it work with PyDESeq2,
|
66
|
-
# ingoring the inconsistency of function signatures. Once arbitrary design
|
67
|
-
# matrices and contrasts are supported by PyDEseq2, we can fully support the
|
68
|
-
# Linear model interface.
|
69
|
-
def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame: # type: ignore
|
62
|
+
def _test_single_contrast(self, contrast, alpha=0.05, **kwargs) -> pd.DataFrame:
|
70
63
|
"""Conduct a specific test and returns a Pandas DataFrame.
|
71
64
|
|
72
65
|
Args:
|
@@ -74,6 +67,7 @@ class PyDESeq2(LinearModelBase):
|
|
74
67
|
alpha: p value threshold used for controlling fdr with independent hypothesis weighting
|
75
68
|
**kwargs: extra arguments to pass to DeseqStats()
|
76
69
|
"""
|
70
|
+
contrast = np.array(contrast)
|
77
71
|
stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
|
78
72
|
# Calling `.summary()` is required to fill the `results_df` data frame
|
79
73
|
stat_res.summary()
|
@@ -85,11 +79,3 @@ class PyDESeq2(LinearModelBase):
|
|
85
79
|
res_df.index.name = "variable"
|
86
80
|
res_df = res_df.reset_index()
|
87
81
|
return res_df
|
88
|
-
|
89
|
-
def cond(self, **kwargs) -> ndarray:
|
90
|
-
raise NotImplementedError(
|
91
|
-
"PyDESeq2 currently doesn't support arbitrary contrasts, see https://github.com/owkin/PyDESeq2/issues/213"
|
92
|
-
)
|
93
|
-
|
94
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> tuple[str, str, str]: # type: ignore
|
95
|
-
return (column, group_to_compare, baseline)
|
@@ -59,14 +59,3 @@ class Statsmodels(LinearModelBase):
|
|
59
59
|
}
|
60
60
|
)
|
61
61
|
return pd.DataFrame(res).sort_values("p_value")
|
62
|
-
|
63
|
-
def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndarray:
|
64
|
-
"""Build a simple contrast for pairwise comparisons.
|
65
|
-
|
66
|
-
This is equivalent to
|
67
|
-
|
68
|
-
```
|
69
|
-
model.cond(<column> = baseline) - model.cond(<column> = group_to_compare)
|
70
|
-
```
|
71
|
-
"""
|
72
|
-
return self.cond(**{column: baseline}) - self.cond(**{column: group_to_compare})
|
@@ -344,9 +344,9 @@ class Distance:
|
|
344
344
|
else:
|
345
345
|
embedding = adata.obsm[self.obsm_key].copy()
|
346
346
|
for index_x, group_x in enumerate(fct(groups)):
|
347
|
-
cells_x = embedding[grouping == group_x].copy()
|
347
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
348
348
|
for group_y in groups[index_x:]: # type: ignore
|
349
|
-
cells_y = embedding[grouping == group_y].copy()
|
349
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
350
350
|
if not bootstrap:
|
351
351
|
# By distance axiom, the distance between a group and itself is 0
|
352
352
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -478,9 +478,9 @@ class Distance:
|
|
478
478
|
else:
|
479
479
|
embedding = adata.obsm[self.obsm_key].copy()
|
480
480
|
for group_x in fct(groups):
|
481
|
-
cells_x = embedding[grouping == group_x].copy()
|
481
|
+
cells_x = embedding[np.asarray(grouping == group_x)].copy()
|
482
482
|
group_y = selected_group
|
483
|
-
cells_y = embedding[grouping == group_y].copy()
|
483
|
+
cells_y = embedding[np.asarray(grouping == group_y)].copy()
|
484
484
|
if not bootstrap:
|
485
485
|
# By distance axiom, the distance between a group and itself is 0
|
486
486
|
dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
|
@@ -691,17 +691,18 @@ class MMD(AbstractDistance):
|
|
691
691
|
|
692
692
|
|
693
693
|
class WassersteinDistance(AbstractDistance):
|
694
|
-
"""Wasserstein distance metric (solved with entropy regularized Sinkhorn)."""
|
695
|
-
|
696
694
|
def __init__(self) -> None:
|
697
695
|
super().__init__()
|
698
696
|
self.accepts_precomputed = False
|
699
697
|
|
700
698
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
699
|
+
X = np.asarray(X, dtype=np.float64)
|
700
|
+
Y = np.asarray(Y, dtype=np.float64)
|
701
701
|
geom = PointCloud(X, Y)
|
702
702
|
return self.solve_ot_problem(geom, **kwargs)
|
703
703
|
|
704
704
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
705
|
+
P = np.asarray(P, dtype=np.float64)
|
705
706
|
geom = Geometry(cost_matrix=P[idx, :][:, ~idx])
|
706
707
|
return self.solve_ot_problem(geom, **kwargs)
|
707
708
|
|
@@ -709,7 +710,13 @@ class WassersteinDistance(AbstractDistance):
|
|
709
710
|
ot_prob = LinearProblem(geom)
|
710
711
|
solver = Sinkhorn()
|
711
712
|
ot = solver(ot_prob, **kwargs)
|
712
|
-
|
713
|
+
cost = float(ot.reg_ot_cost)
|
714
|
+
|
715
|
+
# Check for NaN or invalid cost
|
716
|
+
if not np.isfinite(cost):
|
717
|
+
return 1.0
|
718
|
+
else:
|
719
|
+
return cost
|
713
720
|
|
714
721
|
|
715
722
|
class EuclideanDistance(AbstractDistance):
|
@@ -981,7 +988,7 @@ class NBLL(AbstractDistance):
|
|
981
988
|
try:
|
982
989
|
nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params
|
983
990
|
return _compute_nll(y, nb_params, epsilon)
|
984
|
-
except np.linalg.
|
991
|
+
except np.linalg.LinAlgError:
|
985
992
|
if x.mean() < 10 and y.mean() < 10:
|
986
993
|
return 0.0
|
987
994
|
else:
|
pertpy/tools/_enrichment.py
CHANGED
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
|
3
3
|
from typing import Any, Literal
|
4
4
|
|
5
5
|
import blitzgsea
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
@@ -14,6 +15,7 @@ from scipy.sparse import issparse
|
|
14
15
|
from scipy.stats import hypergeom
|
15
16
|
from statsmodels.stats.multitest import multipletests
|
16
17
|
|
18
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
19
|
from pertpy.metadata import Drug
|
18
20
|
|
19
21
|
|
@@ -290,9 +292,11 @@ class Enrichment:
|
|
290
292
|
|
291
293
|
return enrichment
|
292
294
|
|
295
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
293
296
|
def plot_dotplot(
|
294
297
|
self,
|
295
298
|
adata: AnnData,
|
299
|
+
*,
|
296
300
|
targets: dict[str, dict[str, list[str]]] = None,
|
297
301
|
source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
|
298
302
|
category_name: str = "interaction_type",
|
@@ -300,10 +304,10 @@ class Enrichment:
|
|
300
304
|
groupby: str = None,
|
301
305
|
key: str = "pertpy_enrichment",
|
302
306
|
ax: Axes | None = None,
|
303
|
-
|
304
|
-
|
307
|
+
show: bool = True,
|
308
|
+
return_fig: bool = False,
|
305
309
|
**kwargs,
|
306
|
-
) -> DotPlot |
|
310
|
+
) -> DotPlot | None:
|
307
311
|
"""Plots a dotplot by groupby and categories.
|
308
312
|
|
309
313
|
Wraps scanpy's dotplot but formats it nicely by categories.
|
@@ -319,11 +323,11 @@ class Enrichment:
|
|
319
323
|
category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
|
320
324
|
groupby: dotplot groupby such as clusters or cell types.
|
321
325
|
key: Prefix key of enrichment results in `uns`.
|
326
|
+
{common_plot_args}
|
322
327
|
kwargs: Passed to scanpy dotplot.
|
323
328
|
|
324
329
|
Returns:
|
325
|
-
If `return_fig` is `True`, returns
|
326
|
-
else if `show` is false, return axes dict.
|
330
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
327
331
|
|
328
332
|
Examples:
|
329
333
|
>>> import pertpy as pt
|
@@ -403,21 +407,27 @@ class Enrichment:
|
|
403
407
|
"var_group_labels": var_group_labels,
|
404
408
|
}
|
405
409
|
|
406
|
-
|
410
|
+
fig = sc.pl.dotplot(
|
407
411
|
enrichment_score_adata,
|
408
412
|
groupby=groupby,
|
409
413
|
swap_axes=True,
|
410
414
|
ax=ax,
|
411
|
-
|
412
|
-
show=show,
|
415
|
+
show=False,
|
413
416
|
**plot_args,
|
414
417
|
**kwargs,
|
415
418
|
)
|
416
419
|
|
420
|
+
if show:
|
421
|
+
plt.show()
|
422
|
+
if return_fig:
|
423
|
+
return fig
|
424
|
+
return None
|
425
|
+
|
417
426
|
def plot_gsea(
|
418
427
|
self,
|
419
428
|
adata: AnnData,
|
420
429
|
enrichment: dict[str, pd.DataFrame],
|
430
|
+
*,
|
421
431
|
n: int = 10,
|
422
432
|
key: str = "pertpy_enrichment_gsea",
|
423
433
|
interactive_plot: bool = False,
|
pertpy/tools/_milo.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import logging
|
4
3
|
import random
|
5
4
|
import re
|
6
5
|
from typing import TYPE_CHECKING, Literal
|
@@ -14,6 +13,8 @@ from anndata import AnnData
|
|
14
13
|
from lamin_utils import logger
|
15
14
|
from mudata import MuData
|
16
15
|
|
16
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
|
+
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
from collections.abc import Sequence
|
19
20
|
|
@@ -125,7 +126,7 @@ class Milo:
|
|
125
126
|
try:
|
126
127
|
use_rep = adata.uns["neighbors"]["params"]["use_rep"]
|
127
128
|
except KeyError:
|
128
|
-
|
129
|
+
logger.warning("Using X_pca as default embedding")
|
129
130
|
use_rep = "X_pca"
|
130
131
|
try:
|
131
132
|
knn_graph = adata.obsp["connectivities"].copy()
|
@@ -136,7 +137,7 @@ class Milo:
|
|
136
137
|
try:
|
137
138
|
use_rep = adata.uns[neighbors_key]["params"]["use_rep"]
|
138
139
|
except KeyError:
|
139
|
-
|
140
|
+
logger.warning("Using X_pca as default embedding")
|
140
141
|
use_rep = "X_pca"
|
141
142
|
knn_graph = adata.obsp[neighbors_key + "_connectivities"].copy()
|
142
143
|
|
@@ -182,7 +183,7 @@ class Milo:
|
|
182
183
|
knn_dists = adata.obsp[neighbors_key + "_distances"]
|
183
184
|
|
184
185
|
nhood_ixs = adata.obs["nhood_ixs_refined"] == 1
|
185
|
-
dist_mat = knn_dists[nhood_ixs, :]
|
186
|
+
dist_mat = knn_dists[np.asarray(nhood_ixs), :]
|
186
187
|
k_distances = dist_mat.max(1).toarray().ravel()
|
187
188
|
adata.obs["nhood_kth_distance"] = 0
|
188
189
|
adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
|
@@ -703,8 +704,8 @@ class Milo:
|
|
703
704
|
pvalues = sample_adata.var["PValue"]
|
704
705
|
keep_nhoods = ~pvalues.isna() # Filtering in case of test on subset of nhoods
|
705
706
|
o = pvalues[keep_nhoods].argsort()
|
706
|
-
pvalues = pvalues[keep_nhoods][o]
|
707
|
-
w = w[keep_nhoods][o]
|
707
|
+
pvalues = pvalues.loc[keep_nhoods].iloc[o]
|
708
|
+
w = w.loc[keep_nhoods].iloc[o]
|
708
709
|
|
709
710
|
adjp = np.zeros(shape=len(o))
|
710
711
|
adjp[o] = (sum(w) * pvalues / np.cumsum(w))[::-1].cummin()[::-1]
|
@@ -713,9 +714,11 @@ class Milo:
|
|
713
714
|
sample_adata.var["SpatialFDR"] = np.nan
|
714
715
|
sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp
|
715
716
|
|
717
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
716
718
|
def plot_nhood_graph(
|
717
719
|
self,
|
718
720
|
mdata: MuData,
|
721
|
+
*,
|
719
722
|
alpha: float = 0.1,
|
720
723
|
min_logFC: float = 0,
|
721
724
|
min_size: int = 10,
|
@@ -724,10 +727,10 @@ class Milo:
|
|
724
727
|
color_map: Colormap | str | None = None,
|
725
728
|
palette: str | Sequence[str] | None = None,
|
726
729
|
ax: Axes | None = None,
|
727
|
-
show: bool
|
728
|
-
|
730
|
+
show: bool = True,
|
731
|
+
return_fig: bool = False,
|
729
732
|
**kwargs,
|
730
|
-
) -> None:
|
733
|
+
) -> Figure | None:
|
731
734
|
"""Visualize DA results on abstracted graph (wrapper around sc.pl.embedding)
|
732
735
|
|
733
736
|
Args:
|
@@ -737,9 +740,7 @@ class Milo:
|
|
737
740
|
min_size: Minimum size of nodes in visualization. (default: 10)
|
738
741
|
plot_edges: If edges for neighbourhood overlaps whould be plotted.
|
739
742
|
title: Plot title.
|
740
|
-
|
741
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
742
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
743
|
+
{common_plot_args}
|
743
744
|
**kwargs: Additional arguments to `scanpy.pl.embedding`.
|
744
745
|
|
745
746
|
Examples:
|
@@ -782,7 +783,7 @@ class Milo:
|
|
782
783
|
vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
|
783
784
|
vmin = -vmax
|
784
785
|
|
785
|
-
sc.pl.embedding(
|
786
|
+
fig = sc.pl.embedding(
|
786
787
|
nhood_adata,
|
787
788
|
"X_milo_graph",
|
788
789
|
color="graph_color",
|
@@ -798,33 +799,42 @@ class Milo:
|
|
798
799
|
color_map=color_map,
|
799
800
|
palette=palette,
|
800
801
|
ax=ax,
|
801
|
-
show=
|
802
|
-
save=save,
|
802
|
+
show=False,
|
803
803
|
**kwargs,
|
804
804
|
)
|
805
805
|
|
806
|
+
if show:
|
807
|
+
plt.show()
|
808
|
+
if return_fig:
|
809
|
+
return fig
|
810
|
+
return None
|
811
|
+
|
812
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
806
813
|
def plot_nhood(
|
807
814
|
self,
|
808
815
|
mdata: MuData,
|
809
816
|
ix: int,
|
817
|
+
*,
|
810
818
|
feature_key: str | None = "rna",
|
811
819
|
basis: str = "X_umap",
|
812
820
|
color_map: Colormap | str | None = None,
|
813
821
|
palette: str | Sequence[str] | None = None,
|
814
|
-
return_fig: bool | None = None,
|
815
822
|
ax: Axes | None = None,
|
816
|
-
show: bool
|
817
|
-
|
823
|
+
show: bool = True,
|
824
|
+
return_fig: bool = False,
|
818
825
|
**kwargs,
|
819
|
-
) -> None:
|
826
|
+
) -> Figure | None:
|
820
827
|
"""Visualize cells in a neighbourhood.
|
821
828
|
|
822
829
|
Args:
|
823
830
|
mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
|
824
831
|
ix: index of neighbourhood to visualize
|
832
|
+
feature_key: Key in mdata to the cell-level AnnData object.
|
825
833
|
basis: Embedding to use for visualization.
|
826
|
-
|
827
|
-
|
834
|
+
color_map: Colormap to use for coloring.
|
835
|
+
palette: Color palette to use for coloring.
|
836
|
+
ax: Axes to plot on.
|
837
|
+
{common_plot_args}
|
828
838
|
**kwargs: Additional arguments to `scanpy.pl.embedding`.
|
829
839
|
|
830
840
|
Examples:
|
@@ -842,7 +852,7 @@ class Milo:
|
|
842
852
|
.. image:: /_static/docstring_previews/milo_nhood.png
|
843
853
|
"""
|
844
854
|
mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
|
845
|
-
sc.pl.embedding(
|
855
|
+
fig = sc.pl.embedding(
|
846
856
|
mdata[feature_key],
|
847
857
|
basis,
|
848
858
|
color="Nhood",
|
@@ -852,32 +862,43 @@ class Milo:
|
|
852
862
|
palette=palette,
|
853
863
|
return_fig=return_fig,
|
854
864
|
ax=ax,
|
855
|
-
show=
|
856
|
-
save=save,
|
865
|
+
show=False,
|
857
866
|
**kwargs,
|
858
867
|
)
|
859
868
|
|
869
|
+
if show:
|
870
|
+
plt.show()
|
871
|
+
if return_fig:
|
872
|
+
return fig
|
873
|
+
return None
|
874
|
+
|
875
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
860
876
|
def plot_da_beeswarm(
|
861
877
|
self,
|
862
878
|
mdata: MuData,
|
879
|
+
*,
|
863
880
|
feature_key: str | None = "rna",
|
864
881
|
anno_col: str = "nhood_annotation",
|
865
882
|
alpha: float = 0.1,
|
866
883
|
subset_nhoods: list[str] = None,
|
867
884
|
palette: str | Sequence[str] | dict[str, str] | None = None,
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
) -> Figure | Axes | None:
|
885
|
+
show: bool = True,
|
886
|
+
return_fig: bool = False,
|
887
|
+
) -> Figure | None:
|
872
888
|
"""Plot beeswarm plot of logFC against nhood labels
|
873
889
|
|
874
890
|
Args:
|
875
891
|
mdata: MuData object
|
892
|
+
feature_key: Key in mdata to the cell-level AnnData object.
|
876
893
|
anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
|
877
894
|
alpha: Significance threshold. (default: 0.1)
|
878
895
|
subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
|
879
896
|
palette: Name of Seaborn color palette for violinplots.
|
880
897
|
Defaults to pre-defined category colors for violinplots.
|
898
|
+
{common_plot_args}
|
899
|
+
|
900
|
+
Returns:
|
901
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
881
902
|
|
882
903
|
Examples:
|
883
904
|
>>> import pertpy as pt
|
@@ -973,29 +994,23 @@ class Milo:
|
|
973
994
|
plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
|
974
995
|
plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
|
975
996
|
|
976
|
-
if save:
|
977
|
-
plt.savefig(save, bbox_inches="tight")
|
978
|
-
return None
|
979
997
|
if show:
|
980
998
|
plt.show()
|
981
|
-
return None
|
982
999
|
if return_fig:
|
983
1000
|
return plt.gcf()
|
984
|
-
if (not show and not save) or (show is None and save is None):
|
985
|
-
return plt.gca()
|
986
|
-
|
987
1001
|
return None
|
988
1002
|
|
1003
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
989
1004
|
def plot_nhood_counts_by_cond(
|
990
1005
|
self,
|
991
1006
|
mdata: MuData,
|
992
1007
|
test_var: str,
|
1008
|
+
*,
|
993
1009
|
subset_nhoods: list[str] = None,
|
994
1010
|
log_counts: bool = False,
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
) -> Figure | Axes | None:
|
1011
|
+
show: bool = True,
|
1012
|
+
return_fig: bool = False,
|
1013
|
+
) -> Figure | None:
|
999
1014
|
"""Plot boxplot of cell numbers vs condition of interest.
|
1000
1015
|
|
1001
1016
|
Args:
|
@@ -1003,6 +1018,10 @@ class Milo:
|
|
1003
1018
|
test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
|
1004
1019
|
subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
|
1005
1020
|
log_counts: Whether to plot log1p of cell counts.
|
1021
|
+
{common_plot_args}
|
1022
|
+
|
1023
|
+
Returns:
|
1024
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1006
1025
|
"""
|
1007
1026
|
try:
|
1008
1027
|
nhood_adata = mdata["milo"].T.copy()
|
@@ -1031,15 +1050,8 @@ class Milo:
|
|
1031
1050
|
plt.xticks(rotation=90)
|
1032
1051
|
plt.xlabel(test_var)
|
1033
1052
|
|
1034
|
-
if save:
|
1035
|
-
plt.savefig(save, bbox_inches="tight")
|
1036
|
-
return None
|
1037
1053
|
if show:
|
1038
1054
|
plt.show()
|
1039
|
-
return None
|
1040
1055
|
if return_fig:
|
1041
1056
|
return plt.gcf()
|
1042
|
-
if not (show or save):
|
1043
|
-
return plt.gca()
|
1044
|
-
|
1045
1057
|
return None
|