pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,11 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.3
|
2
2
|
Name: pertpy
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.7.0
|
4
4
|
Summary: Perturbation Analysis in the scverse ecosystem.
|
5
5
|
Project-URL: Documentation, https://pertpy.readthedocs.io
|
6
6
|
Project-URL: Source, https://github.com/theislab/pertpy
|
7
7
|
Project-URL: Home-page, https://github.com/theislab/pertpy
|
8
|
-
Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Amir Moinfar, Sergei Rybakov, Tessa Green, Stefan Peidli, Antonia Schumacher
|
8
|
+
Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Amir Moinfar, Sergei Rybakov, Tessa Green, Stefan Peidli, Antonia Schumacher, Lilly May
|
9
9
|
Maintainer-email: Lukas Heumos <lukas.heumos@posteo.net>
|
10
10
|
License: MIT License
|
11
11
|
|
@@ -29,18 +29,31 @@ License: MIT License
|
|
29
29
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
30
30
|
SOFTWARE.
|
31
31
|
License-File: LICENSE
|
32
|
-
|
32
|
+
Classifier: Development Status :: 4 - Beta
|
33
|
+
Classifier: Environment :: Console
|
34
|
+
Classifier: Framework :: Jupyter
|
35
|
+
Classifier: Intended Audience :: Developers
|
36
|
+
Classifier: Intended Audience :: Science/Research
|
37
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
38
|
+
Classifier: Natural Language :: English
|
39
|
+
Classifier: Operating System :: MacOS :: MacOS X
|
40
|
+
Classifier: Operating System :: POSIX :: Linux
|
41
|
+
Classifier: Programming Language :: Python :: 3
|
42
|
+
Classifier: Programming Language :: Python :: 3.10
|
43
|
+
Classifier: Programming Language :: Python :: 3.11
|
44
|
+
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
45
|
+
Classifier: Topic :: Scientific/Engineering :: Visualization
|
46
|
+
Requires-Python: >=3.10
|
33
47
|
Requires-Dist: adjusttext
|
34
|
-
Requires-Dist: anndata
|
35
48
|
Requires-Dist: arviz
|
49
|
+
Requires-Dist: blitzgsea
|
36
50
|
Requires-Dist: decoupler
|
37
|
-
Requires-Dist: ipywidgets
|
38
51
|
Requires-Dist: muon
|
39
|
-
Requires-Dist: numba
|
40
52
|
Requires-Dist: numpyro
|
41
53
|
Requires-Dist: openpyxl
|
42
54
|
Requires-Dist: ott-jax
|
43
|
-
Requires-Dist:
|
55
|
+
Requires-Dist: pubchempy
|
56
|
+
Requires-Dist: pyarrow
|
44
57
|
Requires-Dist: requests
|
45
58
|
Requires-Dist: rich
|
46
59
|
Requires-Dist: scanpy[leiden]
|
@@ -76,13 +89,13 @@ Requires-Dist: sphinx>=4; extra == 'doc'
|
|
76
89
|
Requires-Dist: sphinxcontrib-bibtex>=1.0.0; extra == 'doc'
|
77
90
|
Requires-Dist: sphinxext-opengraph; extra == 'doc'
|
78
91
|
Provides-Extra: test
|
92
|
+
Requires-Dist: coverage; extra == 'test'
|
79
93
|
Requires-Dist: pytest; extra == 'test'
|
80
|
-
Requires-Dist: pytest-cov; extra == 'test'
|
81
94
|
Description-Content-Type: text/markdown
|
82
95
|
|
83
96
|
[](https://github.com/psf/black)
|
84
97
|
[](https://github.com/theislab/pertpy/actions/workflows/build.yml)
|
85
|
-
[](https://codecov.io/gh/theislab/pertpy)
|
86
99
|
[](https://opensource.org/licenses/Apache2.0)
|
87
100
|
[](https://pypi.org/project/pertpy/)
|
88
101
|
[](https://pypi.org/project/pertpy)
|
@@ -92,7 +105,7 @@ Description-Content-Type: text/markdown
|
|
92
105
|
|
93
106
|
# pertpy
|
94
107
|
|
95
|
-

|
96
109
|
|
97
110
|
## Documentation
|
98
111
|
|
@@ -103,7 +116,13 @@ Please read the [documentation](https://pertpy.readthedocs.io/en/latest).
|
|
103
116
|
You can install _pertpy_ via [pip] from [PyPI]:
|
104
117
|
|
105
118
|
```console
|
106
|
-
|
119
|
+
pip install pertpy
|
120
|
+
```
|
121
|
+
|
122
|
+
if you want to use scCODA please install it as:
|
123
|
+
|
124
|
+
```console
|
125
|
+
pip install pertpy[coda]
|
107
126
|
```
|
108
127
|
|
109
128
|
[pip]: https://pip.pypa.io/
|
@@ -0,0 +1,53 @@
|
|
1
|
+
pertpy/__init__.py,sha256=HI_6a6S0V2YQmRZrlopJ9fxrHVEsBc0GhcX-fZZQCWk,600
|
2
|
+
pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
pertpy/data/__init__.py,sha256=zj4jWyw_Dr6HJvdznQi9QUkZCoFN6vdFRhFF_tw4jts,1275
|
4
|
+
pertpy/data/_dataloader.py,sha256=UTIFK2pofjO6HJjj1dJRbZiQ04_Sslg0d6ggBRkaTnk,2325
|
5
|
+
pertpy/data/_datasets.py,sha256=78RQb4lGFhCDQdAwb8p_MYy4ZxFjRhhfAsKkuWMl5EQ,63179
|
6
|
+
pertpy/metadata/__init__.py,sha256=9FC_FfcpGTsc3j5orjr3USkfsT63REnQBy4r7_VRjco,171
|
7
|
+
pertpy/metadata/_cell_line.py,sha256=tA17gWp9uY0jn36S5bt5lfc9zU2HC2DxZ6x5jVRUosU,40644
|
8
|
+
pertpy/metadata/_compound.py,sha256=Y1oaV_Nq4-wDVFPQhenJTjvOlY4HYi8bCxUUN3VNbMc,4861
|
9
|
+
pertpy/metadata/_drug.py,sha256=q3snazwU_wTLC3Js-9pr_CvfBCRGuFBLlHSFVV-h3ko,9357
|
10
|
+
pertpy/metadata/_look_up.py,sha256=Tb9BHZVc05aH4Kg2KLcfrgvra8gEV5aBqwAZ433krMc,29476
|
11
|
+
pertpy/metadata/_metadata.py,sha256=h_m5bjn_DvCOuiie_q_DB6eM9ermJKXSnLJ0i84AGqI,3419
|
12
|
+
pertpy/metadata/_moa.py,sha256=gLjM59ANqK5Ov-Y-v6COr5Bz4nI2O_LEBfJ2jRpR5U8,4941
|
13
|
+
pertpy/plot/__init__.py,sha256=ZbOAZMpnnAy8GR8oSXJhuwA3gr9GDqgzx0fEkz0TFaE,257
|
14
|
+
pertpy/plot/_augur.py,sha256=uHBoKVeNMcQUV23eHX_k3KQ4sa7wJUxn8VYgoGM0XSw,6713
|
15
|
+
pertpy/plot/_coda.py,sha256=qZfkfFT61pwf-rlOYPwJiYDkx_1DTkT4TfG9fObQWoc,26891
|
16
|
+
pertpy/plot/_guide_rna.py,sha256=iIjkOus7tftIv01zaCTamSX2q8bHBI8O4dxv6QDRVM0,2702
|
17
|
+
pertpy/plot/_milopy.py,sha256=LGMJ4UQFAoRXtBC5LrUtskCU3YEQLzGvMOPuxBg1Tak,7414
|
18
|
+
pertpy/plot/_mixscape.py,sha256=wKsY1efBP9VtQptAq79w5-ODcgzomD3QgW8iOF-5LZY,14862
|
19
|
+
pertpy/preprocessing/__init__.py,sha256=uja9T469LLYQAGgrTyFa4MudXci6NXnAgOn97FHXcxA,40
|
20
|
+
pertpy/preprocessing/_guide_rna.py,sha256=9FSu0QCxZujGPxhl-PcU4yRR7ZSGhgg99eJaLmf8hRA,7713
|
21
|
+
pertpy/tools/__init__.py,sha256=zs_xmmmefEFbOQFMFhh_ds9StyuZ1iiSDdk4qhNlQq4,931
|
22
|
+
pertpy/tools/_augur.py,sha256=Lws7-oSjSFGRupoloj7yIA3YNvVt_dcjSnQ8_Ctt2u0,55797
|
23
|
+
pertpy/tools/_cinemaot.py,sha256=EbnucYl-Q3sfPDL9RiA-cXjpGFJYvfF4CaDEkyr7Snc,39501
|
24
|
+
pertpy/tools/_dialogue.py,sha256=u98h147h2PwhWxFQo8limNmt20e2r5La9-wZvS6tXjo,52143
|
25
|
+
pertpy/tools/_differential_gene_expression.py,sha256=eqit0SJ_fIpSYRS5grYZxEln_JkpXyng-q3LUw9ArY8,14360
|
26
|
+
pertpy/tools/_enrichment.py,sha256=U2WY7t--FCQFZsXXBLZYBwF4WO9t5WToWjCGn_-tP6I,21922
|
27
|
+
pertpy/tools/_kernel_pca.py,sha256=3S1D_wrp4vlHUPiRbCAoRbUyY-rVs112Qh-BZHSmTxE,1578
|
28
|
+
pertpy/tools/_milo.py,sha256=SxpEjhvn2FOYQcQ56tG9ts8fS2vpp4DsB3V7lr8QHbc,43388
|
29
|
+
pertpy/tools/_mixscape.py,sha256=5s0BrLsv1uv1tbZkZnNXN2qUzWjxvIuGe2iu9czBIxA,52500
|
30
|
+
pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
31
|
+
pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
|
32
|
+
pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
|
+
pertpy/tools/_coda/_base_coda.py,sha256=L6OEz7PHDve_geS6JdyjMC5fSMNitZ4pYQOyMjSW04E,115262
|
34
|
+
pertpy/tools/_coda/_sccoda.py,sha256=K_zsU8z_YP1hwJV1Urug4fd7kS-q23eLkcSsX_hxqYg,22849
|
35
|
+
pertpy/tools/_coda/_tasccoda.py,sha256=xTe89TMtbUkw7WAEAvWzOQsMotdekP3Vs5zhMtB2t6c,31033
|
36
|
+
pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
|
+
pertpy/tools/_distances/_distance_tests.py,sha256=aC4zspn4nJ2JRk56c76xTgSxl9znY7QI7zLktwsQ8SY,13604
|
38
|
+
pertpy/tools/_distances/_distances.py,sha256=_dB9dyqWvg_yJBMJTuruFw5w-CLtQWqol-28KVifPBM,35619
|
39
|
+
pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
+
pertpy/tools/_perturbation_space/_clustering.py,sha256=a1p0vzjN9nsmtWXWaKIeR3To3mL-7mzqD3HTTiHCpRk,3470
|
41
|
+
pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=JYDPJlyqxeFbBqWqXk0el6ClXIwrUIYTye5ehmQvmao,22080
|
42
|
+
pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
|
43
|
+
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=VwwYjQ8-VfRHp8_97Id320xd3tnvv-wea6iLFkRE0RY,18868
|
44
|
+
pertpy/tools/_perturbation_space/_simple.py,sha256=37zZweW-50Oa8pTQyC53KyNZQH157HurSc8qXZZt7JA,12681
|
45
|
+
pertpy/tools/_scgen/__init__.py,sha256=neOj3JEcXxnIvgSMrSuPaUSGl1upo7yQ7K_NXHb8Rb8,45
|
46
|
+
pertpy/tools/_scgen/_base_components.py,sha256=dIw-_7Z8iCietPF4tnpM7bFHtDksjnaHXwUjp9GoCIQ,2936
|
47
|
+
pertpy/tools/_scgen/_scgen.py,sha256=qPu_zWCdLvcuOqSDNb2FLzxE5Y1uH0uSbpPGu7kOKks,30705
|
48
|
+
pertpy/tools/_scgen/_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
|
49
|
+
pertpy/tools/_scgen/_utils.py,sha256=y0LGS1OLmIVUBq2ZYySM2Up51o3c08-yjTvUkFb3E0U,2841
|
50
|
+
pertpy-0.7.0.dist-info/METADATA,sha256=tSfWThb8uxC4jwbO5VfOIikt8DRBpO4UdHdFde0al3w,5782
|
51
|
+
pertpy-0.7.0.dist-info/WHEEL,sha256=as-1oFTWSeWBgyzh0O_qF439xqBe6AbBgt4MfYe5zwY,87
|
52
|
+
pertpy-0.7.0.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
|
53
|
+
pertpy-0.7.0.dist-info/RECORD,,
|
pertpy/plot/_cinemaot.py
DELETED
@@ -1,81 +0,0 @@
|
|
1
|
-
from typing import Optional
|
2
|
-
|
3
|
-
import matplotlib.pyplot as plt
|
4
|
-
import pandas as pd
|
5
|
-
import scanpy as sc
|
6
|
-
import seaborn as sns
|
7
|
-
from anndata import AnnData
|
8
|
-
from matplotlib.axes import Axes
|
9
|
-
from scanpy.plotting import _utils
|
10
|
-
|
11
|
-
|
12
|
-
class CinemaotPlot:
|
13
|
-
"""Plotting functions for CINEMA-OT. Only includes new functions beyond the scanpy.pl.embedding family."""
|
14
|
-
|
15
|
-
@staticmethod
|
16
|
-
def vis_matching(
|
17
|
-
adata: AnnData,
|
18
|
-
de: AnnData,
|
19
|
-
pert_key: str,
|
20
|
-
control: str,
|
21
|
-
de_label: str,
|
22
|
-
source_label: str,
|
23
|
-
matching_rep: str = "ot",
|
24
|
-
resolution: float = 0.5,
|
25
|
-
normalize: str = "col",
|
26
|
-
title: str = "CINEMA-OT matching matrix",
|
27
|
-
min_val: float = 0.01,
|
28
|
-
show: bool = True,
|
29
|
-
save: Optional[str] = None,
|
30
|
-
ax: Optional[Axes] = None,
|
31
|
-
**kwargs,
|
32
|
-
) -> None:
|
33
|
-
"""Visualize the CINEMA-OT matching matrix.
|
34
|
-
|
35
|
-
Args:
|
36
|
-
adata: the original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted.
|
37
|
-
de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted().
|
38
|
-
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
39
|
-
control: Control category from the `pert_key` column.
|
40
|
-
de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
|
41
|
-
source_label: the confounder / cell type label.
|
42
|
-
matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
|
43
|
-
normalize: normalize the coarse-grained matching matrix by row / column.
|
44
|
-
title: the title for the figure.
|
45
|
-
min_val: The min value to truncate the matching matrix.
|
46
|
-
show: Show the plot, do not return axis.
|
47
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
48
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
49
|
-
**kwargs: Other parameters to input for seaborn.heatmap.
|
50
|
-
"""
|
51
|
-
adata_ = adata[adata.obs[pert_key] == control]
|
52
|
-
|
53
|
-
df = pd.DataFrame(de.obsm[matching_rep])
|
54
|
-
if de_label is None:
|
55
|
-
de_label = "leiden"
|
56
|
-
sc.pp.neighbors(de, use_rep="X_embedding")
|
57
|
-
sc.tl.leiden(de, resolution=resolution)
|
58
|
-
df["de_label"] = de.obs[de_label].astype(str).values
|
59
|
-
df["de_label"] = "Response " + df["de_label"]
|
60
|
-
df = df.groupby("de_label").sum().T
|
61
|
-
df["source_label"] = adata_.obs[source_label].astype(str).values
|
62
|
-
df = df.groupby("source_label").sum()
|
63
|
-
|
64
|
-
if normalize == "col":
|
65
|
-
df = df / df.sum(axis=0)
|
66
|
-
else:
|
67
|
-
df = (df.T / df.sum(axis=1)).T
|
68
|
-
df = df.clip(lower=min_val) - min_val
|
69
|
-
if normalize == "col":
|
70
|
-
df = df / df.sum(axis=0)
|
71
|
-
else:
|
72
|
-
df = (df.T / df.sum(axis=1)).T
|
73
|
-
|
74
|
-
g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
|
75
|
-
plt.title(title)
|
76
|
-
_utils.savefig_or_show("matching_heatmap", show=show, save=save)
|
77
|
-
if not show:
|
78
|
-
if ax is not None:
|
79
|
-
return ax
|
80
|
-
else:
|
81
|
-
return g
|
pertpy/plot/_dialogue.py
DELETED
@@ -1,91 +0,0 @@
|
|
1
|
-
import matplotlib.pyplot as plt
|
2
|
-
import pandas as pd
|
3
|
-
import scanpy as sc
|
4
|
-
import seaborn as sns
|
5
|
-
from anndata import AnnData
|
6
|
-
from seaborn import PairGrid
|
7
|
-
|
8
|
-
|
9
|
-
class DialoguePlot:
|
10
|
-
@staticmethod
|
11
|
-
def split_violins(
|
12
|
-
adata: AnnData,
|
13
|
-
split_key: str,
|
14
|
-
celltype_key=str,
|
15
|
-
split_which: tuple[str, str] = None,
|
16
|
-
mcp: str = "mcp_0",
|
17
|
-
) -> plt.Axes:
|
18
|
-
"""Plots split violin plots for a given MCP and split variable.
|
19
|
-
|
20
|
-
Any cells with a value for split_key not in split_which are removed from the plot.
|
21
|
-
|
22
|
-
Args:
|
23
|
-
adata: Annotated data object.
|
24
|
-
split_key: Variable in adata.obs used to split the data.
|
25
|
-
celltype_key: Key for cell type annotations.
|
26
|
-
split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
|
27
|
-
mcp: Key for MCP data. Defaults to "mcp_0".
|
28
|
-
|
29
|
-
Returns:
|
30
|
-
A :class:`~matplotlib.axes.Axes` object
|
31
|
-
|
32
|
-
Examples:
|
33
|
-
>>> import pertpy as pt
|
34
|
-
>>> import scanpy as sc
|
35
|
-
>>> adata = pt.dt.dialogue_example()
|
36
|
-
>>> sc.pp.pca(adata)
|
37
|
-
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
38
|
-
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
39
|
-
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
40
|
-
>>> pt.pl.dl.split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
|
41
|
-
"""
|
42
|
-
df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
|
43
|
-
if split_which is None:
|
44
|
-
split_which = df[split_key].unique()
|
45
|
-
df = df[df[split_key].isin(split_which)]
|
46
|
-
df[split_key] = df[split_key].cat.remove_unused_categories()
|
47
|
-
|
48
|
-
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
49
|
-
|
50
|
-
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
51
|
-
|
52
|
-
return ax
|
53
|
-
|
54
|
-
@staticmethod
|
55
|
-
def pairplot(adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0") -> PairGrid:
|
56
|
-
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
57
|
-
|
58
|
-
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
|
59
|
-
then creates a pairplot to visualize the relationships between these mean MCP values.
|
60
|
-
|
61
|
-
Args:
|
62
|
-
adata: Annotated data object.
|
63
|
-
celltype_key: Key in adata.obs containing cell type annotations.
|
64
|
-
color: Key in adata.obs for color annotations. This parameter is used as the hue
|
65
|
-
sample_id: Key in adata.obs for the sample annotations.
|
66
|
-
mcp: Key in adata.obs for MCP feature values. Defaults to "mcp_0".
|
67
|
-
|
68
|
-
Returns:
|
69
|
-
Seaborn Pairgrid object.
|
70
|
-
|
71
|
-
Examples:
|
72
|
-
>>> import pertpy as pt
|
73
|
-
>>> import scanpy as sc
|
74
|
-
>>> adata = pt.dt.dialogue_example()
|
75
|
-
>>> sc.pp.pca(adata)
|
76
|
-
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
|
77
|
-
n_counts_key = "nCount_RNA", n_mpcs = 3)
|
78
|
-
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
|
79
|
-
>>> pt.pl.dl.pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
|
80
|
-
"""
|
81
|
-
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
|
82
|
-
mean_mcps = mean_mcps.reset_index()
|
83
|
-
mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
|
84
|
-
|
85
|
-
aggstats = adata.obs.groupby([sample_id])[color].describe()
|
86
|
-
aggstats = aggstats.loc[list(mcp_pivot.index), :]
|
87
|
-
aggstats[color] = aggstats["top"]
|
88
|
-
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
89
|
-
ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
|
90
|
-
|
91
|
-
return ax
|
pertpy/plot/_scgen.py
DELETED
@@ -1,337 +0,0 @@
|
|
1
|
-
import numpy as np
|
2
|
-
import pandas as pd
|
3
|
-
import scanpy as sc
|
4
|
-
from adjustText import adjust_text
|
5
|
-
from matplotlib import pyplot
|
6
|
-
from scipy import stats
|
7
|
-
from scvi import REGISTRY_KEYS
|
8
|
-
|
9
|
-
|
10
|
-
class JaxscgenPlot:
|
11
|
-
"""Plotting functions for Jaxscgen."""
|
12
|
-
|
13
|
-
@staticmethod
|
14
|
-
def reg_mean_plot(
|
15
|
-
adata,
|
16
|
-
condition_key,
|
17
|
-
axis_keys,
|
18
|
-
labels,
|
19
|
-
path_to_save="./reg_mean.pdf",
|
20
|
-
save=True,
|
21
|
-
gene_list=None,
|
22
|
-
show=False,
|
23
|
-
top_100_genes=None,
|
24
|
-
verbose=False,
|
25
|
-
legend=True,
|
26
|
-
title=None,
|
27
|
-
x_coeff=0.30,
|
28
|
-
y_coeff=0.8,
|
29
|
-
fontsize=14,
|
30
|
-
**kwargs,
|
31
|
-
):
|
32
|
-
"""Plots mean matching figure for a set of specific genes.
|
33
|
-
|
34
|
-
Args:
|
35
|
-
adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
|
36
|
-
AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
|
37
|
-
corresponding to batch and cell type metadata, respectively.
|
38
|
-
condition_key: The key for the condition
|
39
|
-
axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
|
40
|
-
`{"x": "Key for x-axis", "y": "Key for y-axis"}`.
|
41
|
-
labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
|
42
|
-
path_to_save: path to save the plot.
|
43
|
-
save: Specify if the plot should be saved or not.
|
44
|
-
gene_list: list of gene names to be plotted.
|
45
|
-
show: if `True`: will show to the plot after saving it.
|
46
|
-
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
47
|
-
verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
|
48
|
-
legend: if `True`: plots a legend, defaults to `True`.
|
49
|
-
title: Set if you want the plot to display a title.
|
50
|
-
x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
|
51
|
-
y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
|
52
|
-
fontsize: Fontsize used for text in the plot, defaults to 14.
|
53
|
-
**kwargs:
|
54
|
-
|
55
|
-
Examples:
|
56
|
-
>>> import pertpy at pt
|
57
|
-
>>> data = pt.dt.kang_2018()
|
58
|
-
>>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
59
|
-
>>> model = pt.tl.SCGEN(data)
|
60
|
-
>>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
|
61
|
-
>>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
|
62
|
-
>>> pred.obs['label'] = 'pred'
|
63
|
-
>>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
|
64
|
-
>>> r2_value = pt.pl.scg.reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \
|
65
|
-
labels={"x": "predicted", "y": "ground truth"}, save=False, show=True)
|
66
|
-
"""
|
67
|
-
import seaborn as sns
|
68
|
-
|
69
|
-
sns.set()
|
70
|
-
sns.set(color_codes=True)
|
71
|
-
|
72
|
-
diff_genes = top_100_genes
|
73
|
-
stim = adata[adata.obs[condition_key] == axis_keys["y"]]
|
74
|
-
ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
|
75
|
-
if diff_genes is not None:
|
76
|
-
if hasattr(diff_genes, "tolist"):
|
77
|
-
diff_genes = diff_genes.tolist()
|
78
|
-
adata_diff = adata[:, diff_genes]
|
79
|
-
stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
|
80
|
-
ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
|
81
|
-
x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
|
82
|
-
y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
|
83
|
-
m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
|
84
|
-
if verbose:
|
85
|
-
print("top_100 DEGs mean: ", r_value_diff**2)
|
86
|
-
x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
|
87
|
-
y = np.asarray(np.mean(stim.X, axis=0)).ravel()
|
88
|
-
m, b, r_value, p_value, std_err = stats.linregress(x, y)
|
89
|
-
if verbose:
|
90
|
-
print("All genes mean: ", r_value**2)
|
91
|
-
df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
|
92
|
-
ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
|
93
|
-
ax.tick_params(labelsize=fontsize)
|
94
|
-
if "range" in kwargs:
|
95
|
-
start, stop, step = kwargs.get("range")
|
96
|
-
ax.set_xticks(np.arange(start, stop, step))
|
97
|
-
ax.set_yticks(np.arange(start, stop, step))
|
98
|
-
ax.set_xlabel(labels["x"], fontsize=fontsize)
|
99
|
-
ax.set_ylabel(labels["y"], fontsize=fontsize)
|
100
|
-
if gene_list is not None:
|
101
|
-
texts = []
|
102
|
-
for i in gene_list:
|
103
|
-
j = adata.var_names.tolist().index(i)
|
104
|
-
x_bar = x[j]
|
105
|
-
y_bar = y[j]
|
106
|
-
texts.append(pyplot.text(x_bar, y_bar, i, fontsize=11, color="black"))
|
107
|
-
pyplot.plot(x_bar, y_bar, "o", color="red", markersize=5)
|
108
|
-
# if "y1" in axis_keys.keys():
|
109
|
-
# y1_bar = y1[j]
|
110
|
-
# pyplot.text(x_bar, y1_bar, i, fontsize=11, color="black")
|
111
|
-
if gene_list is not None:
|
112
|
-
adjust_text(
|
113
|
-
texts,
|
114
|
-
x=x,
|
115
|
-
y=y,
|
116
|
-
arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
|
117
|
-
force_points=(0.0, 0.0),
|
118
|
-
)
|
119
|
-
if legend:
|
120
|
-
pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
121
|
-
if title is None:
|
122
|
-
pyplot.title("", fontsize=fontsize)
|
123
|
-
else:
|
124
|
-
pyplot.title(title, fontsize=fontsize)
|
125
|
-
ax.text(
|
126
|
-
max(x) - max(x) * x_coeff,
|
127
|
-
max(y) - y_coeff * max(y),
|
128
|
-
r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
|
129
|
-
fontsize=kwargs.get("textsize", fontsize),
|
130
|
-
)
|
131
|
-
if diff_genes is not None:
|
132
|
-
ax.text(
|
133
|
-
max(x) - max(x) * x_coeff,
|
134
|
-
max(y) - (y_coeff + 0.15) * max(y),
|
135
|
-
r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
|
136
|
-
fontsize=kwargs.get("textsize", fontsize),
|
137
|
-
)
|
138
|
-
if save:
|
139
|
-
pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
|
140
|
-
if show:
|
141
|
-
pyplot.show()
|
142
|
-
pyplot.close()
|
143
|
-
if diff_genes is not None:
|
144
|
-
return r_value**2, r_value_diff**2
|
145
|
-
else:
|
146
|
-
return r_value**2
|
147
|
-
|
148
|
-
@staticmethod
|
149
|
-
def reg_var_plot(
|
150
|
-
adata,
|
151
|
-
condition_key,
|
152
|
-
axis_keys,
|
153
|
-
labels,
|
154
|
-
path_to_save="./reg_var.pdf",
|
155
|
-
save=True,
|
156
|
-
gene_list=None,
|
157
|
-
top_100_genes=None,
|
158
|
-
show=False,
|
159
|
-
legend=True,
|
160
|
-
title=None,
|
161
|
-
verbose=False,
|
162
|
-
x_coeff=0.30,
|
163
|
-
y_coeff=0.8,
|
164
|
-
fontsize=14,
|
165
|
-
**kwargs,
|
166
|
-
):
|
167
|
-
"""Plots variance matching figure for a set of specific genes.
|
168
|
-
|
169
|
-
Args:
|
170
|
-
adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
|
171
|
-
AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
|
172
|
-
corresponding to batch and cell type metadata, respectively.
|
173
|
-
condition_key: Key of the condition.
|
174
|
-
axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
|
175
|
-
`{"x": "Key for x-axis", "y": "Key for y-axis"}`.
|
176
|
-
labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
|
177
|
-
path_to_save: path to save the plot.
|
178
|
-
save: Specify if the plot should be saved or not.
|
179
|
-
gene_list: list of gene names to be plotted.
|
180
|
-
show: if `True`: will show to the plot after saving it.
|
181
|
-
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
182
|
-
legend: if `True`: plots a legend, defaults to `True`.
|
183
|
-
title: Set if you want the plot to display a title.
|
184
|
-
verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
|
185
|
-
x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
|
186
|
-
y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
|
187
|
-
fontsize: Fontsize used for text in the plot, defaults to 14.
|
188
|
-
"""
|
189
|
-
import seaborn as sns
|
190
|
-
|
191
|
-
sns.set()
|
192
|
-
sns.set(color_codes=True)
|
193
|
-
|
194
|
-
sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")
|
195
|
-
diff_genes = top_100_genes
|
196
|
-
stim = adata[adata.obs[condition_key] == axis_keys["y"]]
|
197
|
-
ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
|
198
|
-
if diff_genes is not None:
|
199
|
-
if hasattr(diff_genes, "tolist"):
|
200
|
-
diff_genes = diff_genes.tolist()
|
201
|
-
adata_diff = adata[:, diff_genes]
|
202
|
-
stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
|
203
|
-
ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
|
204
|
-
x_diff = np.asarray(np.var(ctrl_diff.X, axis=0)).ravel()
|
205
|
-
y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
|
206
|
-
m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
|
207
|
-
if verbose:
|
208
|
-
print("Top 100 DEGs var: ", r_value_diff**2)
|
209
|
-
if "y1" in axis_keys.keys():
|
210
|
-
real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
|
211
|
-
x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
|
212
|
-
y = np.asarray(np.var(stim.X, axis=0)).ravel()
|
213
|
-
m, b, r_value, p_value, std_err = stats.linregress(x, y)
|
214
|
-
if verbose:
|
215
|
-
print("All genes var: ", r_value**2)
|
216
|
-
df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
|
217
|
-
ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
|
218
|
-
ax.tick_params(labelsize=fontsize)
|
219
|
-
if "range" in kwargs:
|
220
|
-
start, stop, step = kwargs.get("range")
|
221
|
-
ax.set_xticks(np.arange(start, stop, step))
|
222
|
-
ax.set_yticks(np.arange(start, stop, step))
|
223
|
-
# _p1 = pyplot.scatter(x, y, marker=".", label=f"{axis_keys['x']}-{axis_keys['y']}")
|
224
|
-
# pyplot.plot(x, m * x + b, "-", color="green")
|
225
|
-
ax.set_xlabel(labels["x"], fontsize=fontsize)
|
226
|
-
ax.set_ylabel(labels["y"], fontsize=fontsize)
|
227
|
-
if "y1" in axis_keys.keys():
|
228
|
-
y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
|
229
|
-
_ = pyplot.scatter(
|
230
|
-
x,
|
231
|
-
y1,
|
232
|
-
marker="*",
|
233
|
-
c="grey",
|
234
|
-
alpha=0.5,
|
235
|
-
label=f"{axis_keys['x']}-{axis_keys['y1']}",
|
236
|
-
)
|
237
|
-
if gene_list is not None:
|
238
|
-
for i in gene_list:
|
239
|
-
j = adata.var_names.tolist().index(i)
|
240
|
-
x_bar = x[j]
|
241
|
-
y_bar = y[j]
|
242
|
-
pyplot.text(x_bar, y_bar, i, fontsize=11, color="black")
|
243
|
-
pyplot.plot(x_bar, y_bar, "o", color="red", markersize=5)
|
244
|
-
if "y1" in axis_keys.keys():
|
245
|
-
y1_bar = y1[j]
|
246
|
-
pyplot.text(x_bar, y1_bar, "*", color="black", alpha=0.5)
|
247
|
-
if legend:
|
248
|
-
pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
249
|
-
if title is None:
|
250
|
-
pyplot.title("", fontsize=12)
|
251
|
-
else:
|
252
|
-
pyplot.title(title, fontsize=12)
|
253
|
-
ax.text(
|
254
|
-
max(x) - max(x) * x_coeff,
|
255
|
-
max(y) - y_coeff * max(y),
|
256
|
-
r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
|
257
|
-
fontsize=kwargs.get("textsize", fontsize),
|
258
|
-
)
|
259
|
-
if diff_genes is not None:
|
260
|
-
ax.text(
|
261
|
-
max(x) - max(x) * x_coeff,
|
262
|
-
max(y) - (y_coeff + 0.15) * max(y),
|
263
|
-
r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
|
264
|
-
fontsize=kwargs.get("textsize", fontsize),
|
265
|
-
)
|
266
|
-
|
267
|
-
if save:
|
268
|
-
pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
|
269
|
-
if show:
|
270
|
-
pyplot.show()
|
271
|
-
pyplot.close()
|
272
|
-
if diff_genes is not None:
|
273
|
-
return r_value**2, r_value_diff**2
|
274
|
-
else:
|
275
|
-
return r_value**2
|
276
|
-
|
277
|
-
@staticmethod
|
278
|
-
def binary_classifier(
|
279
|
-
scgen,
|
280
|
-
adata,
|
281
|
-
delta,
|
282
|
-
ctrl_key,
|
283
|
-
stim_key,
|
284
|
-
path_to_save,
|
285
|
-
save=True,
|
286
|
-
fontsize=14,
|
287
|
-
):
|
288
|
-
"""Latent space classifier.
|
289
|
-
|
290
|
-
Builds a linear classifier based on the dot product between
|
291
|
-
the difference vector and the latent representation of each
|
292
|
-
cell and plots the dot product results between delta and latent representation.
|
293
|
-
|
294
|
-
Args:
|
295
|
-
scgen: ScGen object that was trained.
|
296
|
-
adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
|
297
|
-
AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
|
298
|
-
corresponding to batch and cell type metadata, respectively.
|
299
|
-
delta: Difference between stimulated and control cells in latent space
|
300
|
-
ctrl_key: Key for `control` part of the `data` found in `condition_key`.
|
301
|
-
stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
|
302
|
-
path_to_save: Path to save the plot.
|
303
|
-
save: Specify if the plot should be saved or not.
|
304
|
-
fontsize: Set the font size of the plot.
|
305
|
-
"""
|
306
|
-
# matplotlib.rcParams.update(matplotlib.rcParamsDefault)
|
307
|
-
pyplot.close("all")
|
308
|
-
adata = scgen._validate_anndata(adata)
|
309
|
-
condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
|
310
|
-
cd = adata[adata.obs[condition_key] == ctrl_key, :]
|
311
|
-
stim = adata[adata.obs[condition_key] == stim_key, :]
|
312
|
-
all_latent_cd = scgen.get_latent_representation(cd.X)
|
313
|
-
all_latent_stim = scgen.get_latent_representation(stim.X)
|
314
|
-
dot_cd = np.zeros(len(all_latent_cd))
|
315
|
-
dot_sal = np.zeros(len(all_latent_stim))
|
316
|
-
for ind, vec in enumerate(all_latent_cd):
|
317
|
-
dot_cd[ind] = np.dot(delta, vec)
|
318
|
-
for ind, vec in enumerate(all_latent_stim):
|
319
|
-
dot_sal[ind] = np.dot(delta, vec)
|
320
|
-
pyplot.hist(
|
321
|
-
dot_cd,
|
322
|
-
label=ctrl_key,
|
323
|
-
bins=50,
|
324
|
-
)
|
325
|
-
pyplot.hist(dot_sal, label=stim_key, bins=50)
|
326
|
-
pyplot.axvline(0, color="k", linestyle="dashed", linewidth=1)
|
327
|
-
pyplot.title(" ", fontsize=fontsize)
|
328
|
-
pyplot.xlabel(" ", fontsize=fontsize)
|
329
|
-
pyplot.ylabel(" ", fontsize=fontsize)
|
330
|
-
pyplot.xticks(fontsize=fontsize)
|
331
|
-
pyplot.yticks(fontsize=fontsize)
|
332
|
-
ax = pyplot.gca()
|
333
|
-
ax.grid(False)
|
334
|
-
|
335
|
-
if save:
|
336
|
-
pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
|
337
|
-
pyplot.show()
|
File without changes
|