pertpy-0.6.0-py3-none-any.whl → pertpy-0.7.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,129 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from rich import print
9
+ from scanpy import settings
10
+
11
+ from pertpy.data._dataloader import _download
12
+
13
+ from ._look_up import LookUp
14
+ from ._metadata import MetaData
15
+
16
+ if TYPE_CHECKING:
17
+ from anndata import AnnData
18
+
19
+
20
+ class Moa(MetaData):
21
+ """Utilities to fetch metadata for mechanism of action studies."""
22
+
23
+ def __init__(self):
24
+ self.clue = None
25
+
26
+ def _download_clue(self) -> None:
27
+ clue_path = Path(settings.cachedir) / "repurposing_drugs_20200324.txt"
28
+ if not Path(clue_path).exists():
29
+ print("[bold yellow]No metadata file was found for clue. Starting download now.")
30
+ _download(
31
+ url="https://s3.amazonaws.com/data.clue.io/repurposing/downloads/repurposing_drugs_20200324.txt",
32
+ output_file_name="repurposing_drugs_20200324.txt",
33
+ output_path=settings.cachedir,
34
+ block_size=4096,
35
+ is_zip=False,
36
+ )
37
+ self.clue = pd.read_csv(clue_path, sep=" ", skiprows=9)
38
+ self.clue = self.clue[["pert_iname", "moa", "target"]]
39
+
40
+ def annotate(
41
+ self,
42
+ adata: AnnData,
43
+ query_id: str = "perturbation",
44
+ target: str | None = None,
45
+ verbosity: int | str = 5,
46
+ copy: bool = False,
47
+ ) -> AnnData:
48
+ """Annotate cells affected by perturbations by mechanism of action.
49
+
50
+ For each cell, we fetch the mechanism of action and molecular targets of the compounds sourced from clue.io.
51
+
52
+ Args:
53
+ adata: The data object to annotate.
54
+ query_id: The column of `.obs` with the name of a perturbagen. Defaults to 'perturbation'.
55
+ target: The column of `.obs` with target information. If set to None, all MoAs are retrieved without comparing molecular targets.
56
+ Defaults to None.
57
+ verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'.
58
+ Defaults to 5.
59
+ copy: Determines whether a copy of the `adata` is returned. Defaults to False.
60
+
61
+ Returns:
62
+ Returns an AnnData object with MoA annotation.
63
+ """
64
+ if copy:
65
+ adata = adata.copy()
66
+
67
+ if query_id not in adata.obs.columns:
68
+ raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n" "Please check again.")
69
+
70
+ if self.clue is None:
71
+ self._download_clue()
72
+
73
+ identifier_num_all = len(adata.obs[query_id].unique())
74
+ not_matched_identifiers = list(set(adata.obs[query_id].str.lower()) - set(self.clue["pert_iname"].str.lower()))
75
+ self._warn_unmatch(
76
+ total_identifiers=identifier_num_all,
77
+ unmatched_identifiers=not_matched_identifiers,
78
+ query_id=query_id,
79
+ reference_id="pert_iname",
80
+ metadata_type="moa",
81
+ verbosity=verbosity,
82
+ )
83
+
84
+ adata.obs = (
85
+ adata.obs.merge(
86
+ self.clue,
87
+ left_on=adata.obs[query_id].str.lower(),
88
+ right_on=self.clue["pert_iname"].str.lower(),
89
+ how="left",
90
+ suffixes=("", "_fromMeta"),
91
+ )
92
+ .set_index(adata.obs.index)
93
+ .drop("key_0", axis=1)
94
+ )
95
+
96
+ # If target column is given, check whether it is one of the targets listed in the metadata
97
+ # If inconsistent, treat this perturbagen as unmatched and overwrite the annotated metadata with NaN
98
+ if target is not None:
99
+ target_meta = "target" if target != "target" else "target_fromMeta"
100
+ adata.obs[target_meta] = adata.obs[target_meta].mask(
101
+ ~adata.obs.apply(lambda row: str(row[target]) in str(row[target_meta]), axis=1)
102
+ )
103
+ pertname_meta = "pert_iname" if query_id != "pert_iname" else "pert_iname_fromMeta"
104
+ adata.obs.loc[adata.obs[target_meta].isna(), [pertname_meta, "moa"]] = np.nan
105
+
106
+ # If query_id and reference_id have different names, there will be a column for each of them after merging
107
+ # which is redundant as they refer to the same information.
108
+ if query_id != "pert_iname":
109
+ del adata.obs["pert_iname"]
110
+
111
+ return adata
112
+
113
+ def lookup(self) -> LookUp:
114
+ """Generate LookUp object for Moa metadata.
115
+
116
+ The LookUp object provides an overview of the metadata to annotate.
117
+ annotate_moa function has a corresponding lookup function in the LookUp object,
118
+ where users can search the query_ids and targets in the metadata.
119
+
120
+ Returns:
121
+ Returns a LookUp object specific for MoA annotation.
122
+ """
123
+ if self.clue is None:
124
+ self._download_clue()
125
+
126
+ return LookUp(
127
+ type="moa",
128
+ transfer_metadata=[self.clue],
129
+ )
pertpy/plot/__init__.py CHANGED
@@ -1,13 +1,5 @@
1
1
  from pertpy.plot._augur import AugurpyPlot as ag
2
- from pertpy.plot._dialogue import DialoguePlot as dl
3
-
4
- try:
5
- from pertpy.plot._coda import CodaPlot as coda
6
- except ImportError:
7
- pass
8
-
9
- from pertpy.plot._cinemaot import CinemaotPlot as cot
2
+ from pertpy.plot._coda import CodaPlot as coda
10
3
  from pertpy.plot._guide_rna import GuideRnaPlot as guide
11
4
  from pertpy.plot._milopy import MilopyPlot as milo
12
5
  from pertpy.plot._mixscape import MixscapePlot as ms
13
- from pertpy.plot._scgen import JaxscgenPlot as scg
pertpy/plot/_augur.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import TYPE_CHECKING, Any
4
5
 
5
6
  from anndata import AnnData
@@ -15,7 +16,7 @@ class AugurpyPlot:
15
16
  """Plotting functions for Augurpy."""
16
17
 
17
18
  @staticmethod
18
- def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure: bool = False) -> Figure | Axes:
19
+ def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None) -> Figure | Axes:
19
20
  """Plot result of differential prioritization.
20
21
 
21
22
  Args:
@@ -42,38 +43,24 @@ class AugurpyPlot:
42
43
 
43
44
  >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
44
45
  permuted_results1=results_15_permute, permuted_results2=results_48_permute)
45
- >>> pt.pl.ag.dp_scatter(pvals)
46
+ >>> ag_rfc.plot_dp_scatter(pvals)
46
47
  """
47
- x = results["mean_augur_score1"]
48
- y = results["mean_augur_score2"]
49
-
50
- if ax is None:
51
- fig, ax = plt.subplots()
52
- scatter = ax.scatter(x, y, c=results.z, cmap="Greens")
53
-
54
- # adding optional labels
55
- top_n_index = results.sort_values(by="pval").index[:top_n]
56
- for idx in top_n_index:
57
- ax.annotate(
58
- results.loc[idx, "cell_type"],
59
- (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
60
- )
48
+ warnings.warn(
49
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
50
+ " Please use the corresponding 'pt.tl' object",
51
+ FutureWarning,
52
+ stacklevel=2,
53
+ )
61
54
 
62
- # add diagonal
63
- limits = max(ax.get_xlim(), ax.get_ylim())
64
- (diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
55
+ from pertpy.tools import Augur
65
56
 
66
- # formatting and details
67
- plt.xlabel("Augur scores 1")
68
- plt.ylabel("Augur scores 2")
69
- legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
70
- ax.add_artist(legend1)
57
+ ag = Augur("random_forest_classifier")
71
58
 
72
- return fig if return_figure else ax
59
+ return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax)
73
60
 
74
61
  @staticmethod
75
62
  def important_features(
76
- data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None, return_figure: bool = False
63
+ data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None
77
64
  ) -> Figure | Axes:
78
65
  """Plot a lollipop plot of the n features with largest feature importances.
79
66
 
@@ -92,44 +79,26 @@ class AugurpyPlot:
92
79
  >>> adata = pt.dt.sc_sim_augur()
93
80
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
94
81
  >>> loaded_data = ag_rfc.load(adata)
95
- >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
96
- >>> pt.pl.ag.important_features(v_results)
82
+ >>> v_adata, v_results = ag_rfc.predict(
83
+ ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
84
+ ... )
85
+ >>> ag_rfc.plot_important_features(v_results)
97
86
  """
98
- if isinstance(data, AnnData):
99
- results = data.uns[key]
100
- else:
101
- results = data
102
- # top_n features to plot
103
- n_features = (
104
- results["feature_importances"]
105
- .groupby("genes", as_index=False)
106
- .feature_importances.mean()
107
- .sort_values(by="feature_importances")[-top_n:]
87
+ warnings.warn(
88
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
89
+ " Please use the corresponding 'pt.tl' object",
90
+ FutureWarning,
91
+ stacklevel=2,
108
92
  )
109
93
 
110
- if ax is None:
111
- fig, ax = plt.subplots()
112
- y_axes_range = range(1, top_n + 1)
113
- ax.hlines(
114
- y_axes_range,
115
- xmin=0,
116
- xmax=n_features["feature_importances"],
117
- )
118
-
119
- # drawing the markers (circle)
120
- ax.plot(n_features["feature_importances"], y_axes_range, "o")
94
+ from pertpy.tools import Augur
121
95
 
122
- # formatting and details
123
- plt.xlabel("Mean Feature Importance")
124
- plt.ylabel("Gene")
125
- plt.yticks(y_axes_range, n_features["genes"])
96
+ ag = Augur("random_forest_classifier")
126
97
 
127
- return fig if return_figure else ax
98
+ return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax)
128
99
 
129
100
  @staticmethod
130
- def lollipop(
131
- data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False
132
- ) -> Figure | Axes:
101
+ def lollipop(data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None) -> Figure | Axes | None:
133
102
  """Plot a lollipop plot of the mean augur values.
134
103
 
135
104
  Args:
@@ -146,40 +115,26 @@ class AugurpyPlot:
146
115
  >>> adata = pt.dt.sc_sim_augur()
147
116
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
148
117
  >>> loaded_data = ag_rfc.load(adata)
149
- >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
150
- >>> pt.pl.ag.lollipop(v_results)
118
+ >>> v_adata, v_results = ag_rfc.predict(
119
+ ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
120
+ ... )
121
+ >>> ag_rfc.plot_lollipop(v_results)
151
122
  """
152
- if isinstance(data, AnnData):
153
- results = data.uns[key]
154
- else:
155
- results = data
156
- if ax is None:
157
- fig, ax = plt.subplots()
158
- y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
159
- ax.hlines(
160
- y_axes_range,
161
- xmin=0,
162
- xmax=results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
123
+ warnings.warn(
124
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
125
+ " Please use the corresponding 'pt.tl' object",
126
+ FutureWarning,
127
+ stacklevel=2,
163
128
  )
164
129
 
165
- # drawing the markers (circle)
166
- ax.plot(
167
- results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
168
- y_axes_range,
169
- "o",
170
- )
130
+ from pertpy.tools import Augur
171
131
 
172
- # formatting and details
173
- plt.xlabel("Mean Augur Score")
174
- plt.ylabel("Cell Type")
175
- plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
132
+ ag = Augur("random_forest_classifier")
176
133
 
177
- return fig if return_figure else ax
134
+ return ag.plot_lollipop(data=data, key=key, ax=ax)
178
135
 
179
136
  @staticmethod
180
- def scatterplot(
181
- results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False
182
- ) -> Figure | Axes:
137
+ def scatterplot(results1: dict[str, Any], results2: dict[str, Any], top_n=None) -> Figure | Axes:
183
138
  """Create scatterplot with two augur results.
184
139
 
185
140
  Args:
@@ -197,38 +152,20 @@ class AugurpyPlot:
197
152
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
198
153
  >>> loaded_data = ag_rfc.load(adata)
199
154
  >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
200
- >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
201
- >>> pt.pl.ag.scatterplot(v_results, h_results)
155
+ >>> v_adata, v_results = ag_rfc.predict(
156
+ ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
157
+ ... )
158
+ >>> ag_rfc.plot_scatterplot(v_results, h_results)
202
159
  """
203
- cell_types = results1["summary_metrics"].columns
204
-
205
- fig, ax = plt.subplots()
206
- ax.scatter(
207
- results1["summary_metrics"].loc["mean_augur_score", cell_types],
208
- results2["summary_metrics"].loc["mean_augur_score", cell_types],
160
+ warnings.warn(
161
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
162
+ " Please use the corresponding 'pt.tl' object",
163
+ FutureWarning,
164
+ stacklevel=2,
209
165
  )
210
166
 
211
- # adding optional labels
212
- top_n_cell_types = (
213
- (results1["summary_metrics"].loc["mean_augur_score"] - results2["summary_metrics"].loc["mean_augur_score"])
214
- .sort_values(ascending=False)
215
- .index[:top_n]
216
- )
217
- for txt in top_n_cell_types:
218
- ax.annotate(
219
- txt,
220
- (
221
- results1["summary_metrics"].loc["mean_augur_score", txt],
222
- results2["summary_metrics"].loc["mean_augur_score", txt],
223
- ),
224
- )
225
-
226
- # adding diagonal
227
- limits = max(ax.get_xlim(), ax.get_ylim())
228
- (diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
229
-
230
- # formatting and details
231
- plt.xlabel("Augur scores 1")
232
- plt.ylabel("Augur scores 2")
233
-
234
- return fig if return_figure else ax
167
+ from pertpy.tools import Augur
168
+
169
+ ag = Augur("random_forest_classifier")
170
+
171
+ return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n)