pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,11 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: pertpy
3
- Version: 0.6.0
3
+ Version: 0.7.0
4
4
  Summary: Perturbation Analysis in the scverse ecosystem.
5
5
  Project-URL: Documentation, https://pertpy.readthedocs.io
6
6
  Project-URL: Source, https://github.com/theislab/pertpy
7
7
  Project-URL: Home-page, https://github.com/theislab/pertpy
8
- Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Amir Moinfar, Sergei Rybakov, Tessa Green, Stefan Peidli, Antonia Schumacher
8
+ Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Amir Moinfar, Sergei Rybakov, Tessa Green, Stefan Peidli, Antonia Schumacher, Lilly May
9
9
  Maintainer-email: Lukas Heumos <lukas.heumos@posteo.net>
10
10
  License: MIT License
11
11
 
@@ -29,18 +29,31 @@ License: MIT License
29
29
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
30
  SOFTWARE.
31
31
  License-File: LICENSE
32
- Requires-Python: >=3.9
32
+ Classifier: Development Status :: 4 - Beta
33
+ Classifier: Environment :: Console
34
+ Classifier: Framework :: Jupyter
35
+ Classifier: Intended Audience :: Developers
36
+ Classifier: Intended Audience :: Science/Research
37
+ Classifier: License :: OSI Approved :: Apache Software License
38
+ Classifier: Natural Language :: English
39
+ Classifier: Operating System :: MacOS :: MacOS X
40
+ Classifier: Operating System :: POSIX :: Linux
41
+ Classifier: Programming Language :: Python :: 3
42
+ Classifier: Programming Language :: Python :: 3.10
43
+ Classifier: Programming Language :: Python :: 3.11
44
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
45
+ Classifier: Topic :: Scientific/Engineering :: Visualization
46
+ Requires-Python: >=3.10
33
47
  Requires-Dist: adjusttext
34
- Requires-Dist: anndata
35
48
  Requires-Dist: arviz
49
+ Requires-Dist: blitzgsea
36
50
  Requires-Dist: decoupler
37
- Requires-Dist: ipywidgets
38
51
  Requires-Dist: muon
39
- Requires-Dist: numba
40
52
  Requires-Dist: numpyro
41
53
  Requires-Dist: openpyxl
42
54
  Requires-Dist: ott-jax
43
- Requires-Dist: plotnine
55
+ Requires-Dist: pubchempy
56
+ Requires-Dist: pyarrow
44
57
  Requires-Dist: requests
45
58
  Requires-Dist: rich
46
59
  Requires-Dist: scanpy[leiden]
@@ -76,13 +89,13 @@ Requires-Dist: sphinx>=4; extra == 'doc'
76
89
  Requires-Dist: sphinxcontrib-bibtex>=1.0.0; extra == 'doc'
77
90
  Requires-Dist: sphinxext-opengraph; extra == 'doc'
78
91
  Provides-Extra: test
92
+ Requires-Dist: coverage; extra == 'test'
79
93
  Requires-Dist: pytest; extra == 'test'
80
- Requires-Dist: pytest-cov; extra == 'test'
81
94
  Description-Content-Type: text/markdown
82
95
 
83
96
  [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
84
97
  [![Build](https://github.com/theislab/pertpy/actions/workflows/build.yml/badge.svg)](https://github.com/theislab/pertpy/actions/workflows/build.yml)
85
- [![Codecov](https://codecov.io/gh/theislab/pertpy/branch/master/graph/badge.svg)](https://codecov.io/gh/theislab/pertpy)
98
+ [![codecov](https://codecov.io/gh/theislab/pertpy/graph/badge.svg?token=1dTpIPBShv)](https://codecov.io/gh/theislab/pertpy)
86
99
  [![License](https://img.shields.io/github/license/theislab/pertpy)](https://opensource.org/licenses/Apache2.0)
87
100
  [![PyPI](https://img.shields.io/pypi/v/pertpy.svg)](https://pypi.org/project/pertpy/)
88
101
  [![Python Version](https://img.shields.io/pypi/pyversions/pertpy)](https://pypi.org/project/pertpy)
@@ -92,7 +105,7 @@ Description-Content-Type: text/markdown
92
105
 
93
106
  # pertpy
94
107
 
95
- ![pertpy-wide1](https://user-images.githubusercontent.com/21954664/235677503-0c72f90d-3f6d-4a16-a1ff-ff8c11a540fb.png)
108
+ ![fig1](https://github.com/theislab/pertpy/assets/99650244/182fa9c3-6d23-4002-b86a-82bf2a243377)
96
109
 
97
110
  ## Documentation
98
111
 
@@ -103,7 +116,13 @@ Please read the [documentation](https://pertpy.readthedocs.io/en/latest).
103
116
  You can install _pertpy_ via [pip] from [PyPI]:
104
117
 
105
118
  ```console
106
- $ pip install pertpy
119
+ pip install pertpy
120
+ ```
121
+
122
+ If you want to use scCODA, please install it as:
123
+
124
+ ```console
125
+ pip install pertpy[coda]
107
126
  ```
108
127
 
109
128
  [pip]: https://pip.pypa.io/
@@ -0,0 +1,53 @@
1
+ pertpy/__init__.py,sha256=HI_6a6S0V2YQmRZrlopJ9fxrHVEsBc0GhcX-fZZQCWk,600
2
+ pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ pertpy/data/__init__.py,sha256=zj4jWyw_Dr6HJvdznQi9QUkZCoFN6vdFRhFF_tw4jts,1275
4
+ pertpy/data/_dataloader.py,sha256=UTIFK2pofjO6HJjj1dJRbZiQ04_Sslg0d6ggBRkaTnk,2325
5
+ pertpy/data/_datasets.py,sha256=78RQb4lGFhCDQdAwb8p_MYy4ZxFjRhhfAsKkuWMl5EQ,63179
6
+ pertpy/metadata/__init__.py,sha256=9FC_FfcpGTsc3j5orjr3USkfsT63REnQBy4r7_VRjco,171
7
+ pertpy/metadata/_cell_line.py,sha256=tA17gWp9uY0jn36S5bt5lfc9zU2HC2DxZ6x5jVRUosU,40644
8
+ pertpy/metadata/_compound.py,sha256=Y1oaV_Nq4-wDVFPQhenJTjvOlY4HYi8bCxUUN3VNbMc,4861
9
+ pertpy/metadata/_drug.py,sha256=q3snazwU_wTLC3Js-9pr_CvfBCRGuFBLlHSFVV-h3ko,9357
10
+ pertpy/metadata/_look_up.py,sha256=Tb9BHZVc05aH4Kg2KLcfrgvra8gEV5aBqwAZ433krMc,29476
11
+ pertpy/metadata/_metadata.py,sha256=h_m5bjn_DvCOuiie_q_DB6eM9ermJKXSnLJ0i84AGqI,3419
12
+ pertpy/metadata/_moa.py,sha256=gLjM59ANqK5Ov-Y-v6COr5Bz4nI2O_LEBfJ2jRpR5U8,4941
13
+ pertpy/plot/__init__.py,sha256=ZbOAZMpnnAy8GR8oSXJhuwA3gr9GDqgzx0fEkz0TFaE,257
14
+ pertpy/plot/_augur.py,sha256=uHBoKVeNMcQUV23eHX_k3KQ4sa7wJUxn8VYgoGM0XSw,6713
15
+ pertpy/plot/_coda.py,sha256=qZfkfFT61pwf-rlOYPwJiYDkx_1DTkT4TfG9fObQWoc,26891
16
+ pertpy/plot/_guide_rna.py,sha256=iIjkOus7tftIv01zaCTamSX2q8bHBI8O4dxv6QDRVM0,2702
17
+ pertpy/plot/_milopy.py,sha256=LGMJ4UQFAoRXtBC5LrUtskCU3YEQLzGvMOPuxBg1Tak,7414
18
+ pertpy/plot/_mixscape.py,sha256=wKsY1efBP9VtQptAq79w5-ODcgzomD3QgW8iOF-5LZY,14862
19
+ pertpy/preprocessing/__init__.py,sha256=uja9T469LLYQAGgrTyFa4MudXci6NXnAgOn97FHXcxA,40
20
+ pertpy/preprocessing/_guide_rna.py,sha256=9FSu0QCxZujGPxhl-PcU4yRR7ZSGhgg99eJaLmf8hRA,7713
21
+ pertpy/tools/__init__.py,sha256=zs_xmmmefEFbOQFMFhh_ds9StyuZ1iiSDdk4qhNlQq4,931
22
+ pertpy/tools/_augur.py,sha256=Lws7-oSjSFGRupoloj7yIA3YNvVt_dcjSnQ8_Ctt2u0,55797
23
+ pertpy/tools/_cinemaot.py,sha256=EbnucYl-Q3sfPDL9RiA-cXjpGFJYvfF4CaDEkyr7Snc,39501
24
+ pertpy/tools/_dialogue.py,sha256=u98h147h2PwhWxFQo8limNmt20e2r5La9-wZvS6tXjo,52143
25
+ pertpy/tools/_differential_gene_expression.py,sha256=eqit0SJ_fIpSYRS5grYZxEln_JkpXyng-q3LUw9ArY8,14360
26
+ pertpy/tools/_enrichment.py,sha256=U2WY7t--FCQFZsXXBLZYBwF4WO9t5WToWjCGn_-tP6I,21922
27
+ pertpy/tools/_kernel_pca.py,sha256=3S1D_wrp4vlHUPiRbCAoRbUyY-rVs112Qh-BZHSmTxE,1578
28
+ pertpy/tools/_milo.py,sha256=SxpEjhvn2FOYQcQ56tG9ts8fS2vpp4DsB3V7lr8QHbc,43388
29
+ pertpy/tools/_mixscape.py,sha256=5s0BrLsv1uv1tbZkZnNXN2qUzWjxvIuGe2iu9czBIxA,52500
30
+ pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
31
+ pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
32
+ pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
+ pertpy/tools/_coda/_base_coda.py,sha256=L6OEz7PHDve_geS6JdyjMC5fSMNitZ4pYQOyMjSW04E,115262
34
+ pertpy/tools/_coda/_sccoda.py,sha256=K_zsU8z_YP1hwJV1Urug4fd7kS-q23eLkcSsX_hxqYg,22849
35
+ pertpy/tools/_coda/_tasccoda.py,sha256=xTe89TMtbUkw7WAEAvWzOQsMotdekP3Vs5zhMtB2t6c,31033
36
+ pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
+ pertpy/tools/_distances/_distance_tests.py,sha256=aC4zspn4nJ2JRk56c76xTgSxl9znY7QI7zLktwsQ8SY,13604
38
+ pertpy/tools/_distances/_distances.py,sha256=_dB9dyqWvg_yJBMJTuruFw5w-CLtQWqol-28KVifPBM,35619
39
+ pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
+ pertpy/tools/_perturbation_space/_clustering.py,sha256=a1p0vzjN9nsmtWXWaKIeR3To3mL-7mzqD3HTTiHCpRk,3470
41
+ pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=JYDPJlyqxeFbBqWqXk0el6ClXIwrUIYTye5ehmQvmao,22080
42
+ pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
43
+ pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=VwwYjQ8-VfRHp8_97Id320xd3tnvv-wea6iLFkRE0RY,18868
44
+ pertpy/tools/_perturbation_space/_simple.py,sha256=37zZweW-50Oa8pTQyC53KyNZQH157HurSc8qXZZt7JA,12681
45
+ pertpy/tools/_scgen/__init__.py,sha256=neOj3JEcXxnIvgSMrSuPaUSGl1upo7yQ7K_NXHb8Rb8,45
46
+ pertpy/tools/_scgen/_base_components.py,sha256=dIw-_7Z8iCietPF4tnpM7bFHtDksjnaHXwUjp9GoCIQ,2936
47
+ pertpy/tools/_scgen/_scgen.py,sha256=qPu_zWCdLvcuOqSDNb2FLzxE5Y1uH0uSbpPGu7kOKks,30705
48
+ pertpy/tools/_scgen/_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
49
+ pertpy/tools/_scgen/_utils.py,sha256=y0LGS1OLmIVUBq2ZYySM2Up51o3c08-yjTvUkFb3E0U,2841
50
+ pertpy-0.7.0.dist-info/METADATA,sha256=tSfWThb8uxC4jwbO5VfOIikt8DRBpO4UdHdFde0al3w,5782
51
+ pertpy-0.7.0.dist-info/WHEEL,sha256=as-1oFTWSeWBgyzh0O_qF439xqBe6AbBgt4MfYe5zwY,87
52
+ pertpy-0.7.0.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
53
+ pertpy-0.7.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.18.0
2
+ Generator: hatchling 1.22.5
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
pertpy/plot/_cinemaot.py DELETED
@@ -1,81 +0,0 @@
1
- from typing import Optional
2
-
3
- import matplotlib.pyplot as plt
4
- import pandas as pd
5
- import scanpy as sc
6
- import seaborn as sns
7
- from anndata import AnnData
8
- from matplotlib.axes import Axes
9
- from scanpy.plotting import _utils
10
-
11
-
12
class CinemaotPlot:
    """Plotting functions for CINEMA-OT. Only includes new functions beyond the scanpy.pl.embedding family."""

    @staticmethod
    def _normalize_matching(df, normalize):
        """Rescale ``df`` so that columns (``normalize == "col"``) or, for any other value, rows sum to one."""
        if normalize == "col":
            return df / df.sum(axis=0)
        return (df.T / df.sum(axis=1)).T

    @staticmethod
    def vis_matching(
        adata: AnnData,
        de: AnnData,
        pert_key: str,
        control: str,
        de_label: Optional[str],
        source_label: str,
        matching_rep: str = "ot",
        resolution: float = 0.5,
        normalize: str = "col",
        title: str = "CINEMA-OT matching matrix",
        min_val: float = 0.01,
        show: bool = True,
        save: Optional[str] = None,
        ax: Optional[Axes] = None,
        **kwargs,
    ) -> Optional[Axes]:
        """Visualize the CINEMA-OT matching matrix as an annotated heatmap.

        The matching matrix is summed within response groups (``de_label``) and within
        source groups (``source_label``), normalized, truncated at ``min_val`` and
        normalized again before plotting.

        Args:
            adata: The original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted.
            de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted().
            pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
            control: Control category from the `pert_key` column.
            de_label: The `de.obs` column holding the differential response label. If `None`, a neighbor
                graph is built on `de.obsm["X_embedding"]`, Leiden clustering is run at `resolution`
                and the resulting `leiden` column is used (this writes to `de`).
            source_label: The `adata.obs` column with the confounder / cell type label.
            matching_rep: The key in `de.obsm` that stores the matching matrix. Defaults to "ot".
            resolution: Leiden resolution, only used when `de_label` is `None`. Defaults to 0.5.
            normalize: "col" normalizes the coarse-grained matching matrix by column; any other
                value normalizes by row.
            title: The title for the figure.
            min_val: The min value to truncate the matching matrix.
            show: Show the plot, do not return axis.
            save: If `True` or a `str`, save the figure. A string is appended to the default filename.
                Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
            ax: Axes to draw the heatmap on. If `None`, seaborn draws on the current axes.
            **kwargs: Other parameters to input for seaborn.heatmap.

        Returns:
            `None` if `show` is `True`, otherwise the axes the heatmap was drawn on.
        """
        control_cells = adata[adata.obs[pert_key] == control]

        matching = pd.DataFrame(de.obsm[matching_rep])
        if de_label is None:
            de_label = "leiden"
            sc.pp.neighbors(de, use_rep="X_embedding")
            sc.tl.leiden(de, resolution=resolution)

        # Collapse the rows of the matching matrix into one row per response group.
        matching["de_label"] = "Response " + de.obs[de_label].astype(str).values
        matching = matching.groupby("de_label").sum().T

        # After the transpose each row is one column of the matching matrix.
        # NOTE(review): presumably one row per control cell, in `control_cells` order - confirm against caller.
        matching["source_label"] = control_cells.obs[source_label].astype(str).values
        matching = matching.groupby("source_label").sum()

        # Normalize, remove small entries, then renormalize so the shown values sum to one again.
        matching = CinemaotPlot._normalize_matching(matching, normalize)
        matching = matching.clip(lower=min_val) - min_val
        matching = CinemaotPlot._normalize_matching(matching, normalize)

        heatmap_ax = sns.heatmap(matching, annot=True, ax=ax, **kwargs)
        # Title the axes that were actually drawn on; `plt.title` would target the
        # current axes, which may not be the user-supplied `ax`.
        heatmap_ax.set_title(title)
        _utils.savefig_or_show("matching_heatmap", show=show, save=save)

        # seaborn returns the axes it drew on, which is `ax` whenever one was supplied.
        return None if show else heatmap_ax
pertpy/plot/_dialogue.py DELETED
@@ -1,91 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import pandas as pd
3
- import scanpy as sc
4
- import seaborn as sns
5
- from anndata import AnnData
6
- from seaborn import PairGrid
7
-
8
-
9
class DialoguePlot:
    """Plotting functions for DIALOGUE multicellular program (MCP) scores stored in ``adata.obs``."""

    @staticmethod
    def split_violins(
        adata: AnnData,
        split_key: str,
        celltype_key: str,
        split_which: tuple[str, str] = None,
        mcp: str = "mcp_0",
    ) -> plt.Axes:
        """Plots split violin plots for a given MCP and split variable.

        Any cells with a value for split_key not in split_which are removed from the plot.

        Args:
            adata: Annotated data object.
            split_key: Variable in adata.obs used to split the data.
            celltype_key: Key for cell type annotations.
            split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
                If `None`, all values present in `split_key` are used.
            mcp: Key for MCP data. Defaults to "mcp_0".

        Returns:
            A :class:`~matplotlib.axes.Axes` object

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> adata = pt.dt.dialogue_example()
            >>> sc.pp.pca(adata)
            >>> dl = pt.tl.Dialogue(
            ...     sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
            ... )
            >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
            >>> pt.pl.dl.split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
        """
        df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
        if split_which is None:
            split_which = df[split_key].unique()

        # Copy the filtered rows so the column assignment below never writes into a view.
        df = df[df[split_key].isin(split_which)].copy()
        # Drop filtered-out levels so they do not show up as empty hue entries; the cast
        # leaves an existing categorical unchanged and makes other dtypes work.
        df[split_key] = df[split_key].astype("category").cat.remove_unused_categories()

        ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

        return ax

    @staticmethod
    def pairplot(adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0") -> PairGrid:
        """Generate a pairplot visualization for multicellular program (MCP) data.

        Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
        then creates a pairplot to visualize the relationships between these mean MCP values.

        Args:
            adata: Annotated data object.
            celltype_key: Key in adata.obs containing cell type annotations.
            color: Key in adata.obs for color annotations. This parameter is used as the hue.
            sample_id: Key in adata.obs for the sample annotations.
            mcp: Key in adata.obs for MCP feature values. Defaults to "mcp_0".

        Returns:
            Seaborn Pairgrid object.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> adata = pt.dt.dialogue_example()
            >>> sc.pp.pca(adata)
            >>> dl = pt.tl.Dialogue(
            ...     sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
            ... )
            >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
            >>> pt.pl.dl.pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
        """
        # Mean MCP value per (sample, cell type), reshaped to samples x cell types.
        mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean().reset_index()
        mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]

        # One hue value per sample: the "top" (most frequent) entry of `color` from describe().
        # NOTE(review): "top" only exists for non-numeric columns - assumes `color` is categorical/string.
        aggstats = adata.obs.groupby([sample_id])[color].describe()
        aggstats = aggstats.loc[list(mcp_pivot.index), :]
        aggstats[color] = aggstats["top"]

        mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
        grid = sns.pairplot(mcp_pivot, hue=color, corner=True)

        return grid
pertpy/plot/_scgen.py DELETED
@@ -1,337 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import scanpy as sc
4
- from adjustText import adjust_text
5
- from matplotlib import pyplot
6
- from scipy import stats
7
- from scvi import REGISTRY_KEYS
8
-
9
-
10
class JaxscgenPlot:
    """Plotting functions for Jaxscgen."""

    @staticmethod
    def _feature_stat(view, reducer):
        """Reduce ``view.X`` over axis 0 with ``reducer`` (e.g. ``np.mean``) and return a flat 1-D array."""
        return np.asarray(reducer(view.X, axis=0)).ravel()

    @staticmethod
    def _subset_r_value(adata, genes, condition_key, axis_keys, reducer):
        """Return the regression r-value between the x and y conditions restricted to ``genes``."""
        if hasattr(genes, "tolist"):
            genes = genes.tolist()
        adata_diff = adata[:, genes]
        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
        x_diff = JaxscgenPlot._feature_stat(ctrl_diff, reducer)
        y_diff = JaxscgenPlot._feature_stat(stim_diff, reducer)
        return stats.linregress(x_diff, y_diff).rvalue

    @staticmethod
    def _annotate_r_squared(ax, x, y, r_value, r_value_diff, x_coeff, y_coeff, textsize):
        """Write the all-genes R^2 and, if ``r_value_diff`` is given, the top-100-DEG R^2 onto ``ax``."""
        x_pos = x.max() - x.max() * x_coeff
        ax.text(
            x_pos,
            y.max() - y_coeff * y.max(),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=textsize,
        )
        if r_value_diff is not None:
            ax.text(
                x_pos,
                y.max() - (y_coeff + 0.15) * y.max(),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=textsize,
            )

    @staticmethod
    def _save_show_close(path_to_save, save, show):
        """Optionally save and/or show the current figure, then always close it."""
        if save:
            pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
        if show:
            pyplot.show()
        pyplot.close()

    @staticmethod
    def reg_mean_plot(
        adata,
        condition_key,
        axis_keys,
        labels,
        path_to_save="./reg_mean.pdf",
        save=True,
        gene_list=None,
        show=False,
        top_100_genes=None,
        verbose=False,
        legend=True,
        title=None,
        x_coeff=0.30,
        y_coeff=0.8,
        fontsize=14,
        **kwargs,
    ):
        """Plots mean matching figure for a set of specific genes.

        Args:
            adata: AnnData object containing the cells of both conditions named in `axis_keys`.
            condition_key: The `adata.obs` column holding the condition of each cell.
            axis_keys: Dictionary of condition values that are used by the axes of the plot.
                Has to be in the following form: `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            path_to_save: Path to save the plot.
            save: Specify if the plot should be saved or not.
            gene_list: List of gene names to be highlighted and labelled in the plot.
            show: If `True`: will show the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes.
                Specify if you want the top 100 DEGs to be assessed extra.
            verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
            legend: If `True`: plots a legend when labelled artists exist, defaults to `True`.
            title: Set if you want the plot to display a title.
            x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
            y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
            fontsize: Fontsize used for text in the plot, defaults to 14.
            **kwargs: `range=(start, stop, step)` sets the ticks of both axes;
                `textsize` overrides the font size of the R^2 annotations.

        Returns:
            The R^2 value over all genes, or the tuple `(R^2 all genes, R^2 top 100 DEGs)`
            if `top_100_genes` was passed.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
            >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
            >>> r2_value = pt.pl.scg.reg_mean_plot(
            ...     eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"},
            ...     labels={"x": "predicted", "y": "ground truth"}, save=False, show=True,
            ... )
        """
        import seaborn as sns

        sns.set(color_codes=True)

        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

        r_value_diff = None
        if top_100_genes is not None:
            r_value_diff = JaxscgenPlot._subset_r_value(adata, top_100_genes, condition_key, axis_keys, np.mean)
            if verbose:
                print("top_100 DEGs mean: ", r_value_diff**2)

        x = JaxscgenPlot._feature_stat(ctrl, np.mean)
        y = JaxscgenPlot._feature_stat(stim, np.mean)
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            print("All genes mean: ", r_value**2)

        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)

        if gene_list is not None:
            # Build the name list once instead of once per highlighted gene.
            var_names = adata.var_names.tolist()
            texts = []
            for gene in gene_list:
                j = var_names.index(gene)
                texts.append(ax.text(x[j], y[j], gene, fontsize=11, color="black"))
                ax.plot(x[j], y[j], "o", color="red", markersize=5)
            # Move the labels so they do not overlap each other or the data points.
            adjust_text(
                texts,
                x=x,
                y=y,
                arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
                force_points=(0.0, 0.0),
            )

        # Only draw a legend when there is something to list; an unconditional call
        # warns and adds an empty legend box.
        if legend and ax.get_legend_handles_labels()[0]:
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        ax.set_title("" if title is None else title, fontsize=fontsize)

        JaxscgenPlot._annotate_r_squared(
            ax, x, y, r_value, r_value_diff, x_coeff, y_coeff, kwargs.get("textsize", fontsize)
        )
        JaxscgenPlot._save_show_close(path_to_save, save, show)

        if r_value_diff is not None:
            return r_value**2, r_value_diff**2
        return r_value**2

    @staticmethod
    def reg_var_plot(
        adata,
        condition_key,
        axis_keys,
        labels,
        path_to_save="./reg_var.pdf",
        save=True,
        gene_list=None,
        top_100_genes=None,
        show=False,
        legend=True,
        title=None,
        verbose=False,
        x_coeff=0.30,
        y_coeff=0.8,
        fontsize=14,
        **kwargs,
    ):
        """Plots variance matching figure for a set of specific genes.

        Args:
            adata: AnnData object containing the cells of all conditions named in `axis_keys`.
            condition_key: The `adata.obs` column holding the condition of each cell.
            axis_keys: Dictionary of condition values that are used by the axes of the plot.
                Has to be in the following form: `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
                An optional `"y1"` entry adds a grey star overlay of that condition against `"x"`.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            path_to_save: Path to save the plot.
            save: Specify if the plot should be saved or not.
            gene_list: List of gene names to be highlighted and labelled in the plot.
            top_100_genes: List of the top 100 differentially expressed genes.
                Specify if you want the top 100 DEGs to be assessed extra.
            show: If `True`: will show the plot after saving it.
            legend: If `True`: plots a legend when labelled artists exist, defaults to `True`.
            title: Set if you want the plot to display a title.
            verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
            x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
            y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
            fontsize: Fontsize used for text in the plot, defaults to 14.
            **kwargs: `range=(start, stop, step)` sets the ticks of both axes;
                `textsize` overrides the font size of the R^2 annotations.

        Returns:
            The R^2 value over all genes, or the tuple `(R^2 all genes, R^2 top 100 DEGs)`
            if `top_100_genes` was passed.
        """
        import seaborn as sns

        sns.set(color_codes=True)

        # NOTE(review): the ranking result is not read anywhere in this function; the call is
        # kept only because it stores its result on `adata`, which callers might rely on.
        sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")

        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

        r_value_diff = None
        if top_100_genes is not None:
            r_value_diff = JaxscgenPlot._subset_r_value(adata, top_100_genes, condition_key, axis_keys, np.var)
            if verbose:
                print("Top 100 DEGs var: ", r_value_diff**2)

        x = JaxscgenPlot._feature_stat(ctrl, np.var)
        y = JaxscgenPlot._feature_stat(stim, np.var)
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            print("All genes var: ", r_value**2)

        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)

        # Optional second comparison: variance of the "y1" condition against the x condition.
        y1 = None
        if "y1" in axis_keys:
            real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
            y1 = JaxscgenPlot._feature_stat(real_stim, np.var)
            ax.scatter(
                x,
                y1,
                marker="*",
                c="grey",
                alpha=0.5,
                label=f"{axis_keys['x']}-{axis_keys['y1']}",
            )

        if gene_list is not None:
            # Build the name list once instead of once per highlighted gene.
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                ax.text(x[j], y[j], gene, fontsize=11, color="black")
                ax.plot(x[j], y[j], "o", color="red", markersize=5)
                if y1 is not None:
                    ax.text(x[j], y1[j], "*", color="black", alpha=0.5)

        # Only draw a legend when there is something to list (i.e. the "y1" overlay exists).
        if legend and ax.get_legend_handles_labels()[0]:
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        # Use the requested font size for the title, matching reg_mean_plot.
        ax.set_title("" if title is None else title, fontsize=fontsize)

        JaxscgenPlot._annotate_r_squared(
            ax, x, y, r_value, r_value_diff, x_coeff, y_coeff, kwargs.get("textsize", fontsize)
        )
        JaxscgenPlot._save_show_close(path_to_save, save, show)

        if r_value_diff is not None:
            return r_value**2, r_value_diff**2
        return r_value**2

    @staticmethod
    def binary_classifier(
        scgen,
        adata,
        delta,
        ctrl_key,
        stim_key,
        path_to_save,
        save=True,
        fontsize=14,
    ):
        """Latent space classifier.

        Builds a linear classifier based on the dot product between
        the difference vector and the latent representation of each
        cell and plots the dot product results between delta and latent representation.
        The figure is always shown at the end.

        Args:
            scgen: ScGen object that was trained.
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.
            delta: Difference between stimulated and control cells in latent space
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            path_to_save: Path to save the plot.
            save: Specify if the plot should be saved or not.
            fontsize: Set the font size of the plot.
        """
        pyplot.close("all")
        adata = scgen._validate_anndata(adata)
        # The condition column is the one registered as batch key when the model was set up.
        condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
        ctrl_cells = adata[adata.obs[condition_key] == ctrl_key, :]
        stim_cells = adata[adata.obs[condition_key] == stim_key, :]
        latent_ctrl = scgen.get_latent_representation(ctrl_cells.X)
        latent_stim = scgen.get_latent_representation(stim_cells.X)

        # Project every cell onto delta in one matrix-vector product per condition.
        direction = np.ravel(delta)
        dot_ctrl = np.dot(latent_ctrl, direction)
        dot_stim = np.dot(latent_stim, direction)

        pyplot.hist(dot_ctrl, label=ctrl_key, bins=50)
        pyplot.hist(dot_stim, label=stim_key, bins=50)
        # Dashed line at zero, the boundary between negative and positive projections.
        pyplot.axvline(0, color="k", linestyle="dashed", linewidth=1)
        pyplot.title(" ", fontsize=fontsize)
        pyplot.xlabel(" ", fontsize=fontsize)
        pyplot.ylabel(" ", fontsize=fontsize)
        pyplot.xticks(fontsize=fontsize)
        pyplot.yticks(fontsize=fontsize)
        ax = pyplot.gca()
        ax.grid(False)

        if save:
            pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
        pyplot.show()
File without changes