pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,129 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from rich import print
9
+ from scanpy import settings
10
+
11
+ from pertpy.data._dataloader import _download
12
+
13
+ from ._look_up import LookUp
14
+ from ._metadata import MetaData
15
+
16
+ if TYPE_CHECKING:
17
+ from anndata import AnnData
18
+
19
+
20
class Moa(MetaData):
    """Utilities to fetch metadata for mechanism of action studies."""

    def __init__(self):
        # Lazily populated by `_download_clue` the first time metadata is needed.
        self.clue = None

    def _download_clue(self) -> None:
        """Load the clue.io drug repurposing table into `self.clue`.

        The file is downloaded into scanpy's cache directory if it is not there yet.
        Only the `pert_iname`, `moa` and `target` columns are kept.
        """
        clue_path = Path(settings.cachedir) / "repurposing_drugs_20200324.txt"
        if not clue_path.exists():
            print("[bold yellow]No metadata file was found for clue. Starting download now.")
            _download(
                url="https://s3.amazonaws.com/data.clue.io/repurposing/downloads/repurposing_drugs_20200324.txt",
                output_file_name="repurposing_drugs_20200324.txt",
                output_path=settings.cachedir,
                block_size=4096,
                is_zip=False,
            )
        # The table must be split on tabs: MoA values contain spaces, so a space
        # separator cannot yield the three named columns selected below.
        # The first 9 rows are skipped before the header row is read.
        self.clue = pd.read_csv(clue_path, sep="\t", skiprows=9)
        self.clue = self.clue[["pert_iname", "moa", "target"]]

    def annotate(
        self,
        adata: AnnData,
        query_id: str = "perturbation",
        target: str | None = None,
        verbosity: int | str = 5,
        copy: bool = False,
    ) -> AnnData:
        """Annotate cells affected by perturbations by mechanism of action.

        For each cell, we fetch the mechanism of action and molecular targets of the compounds sourced from clue.io.
        Perturbagen names are matched case-insensitively against the `pert_iname` column of the metadata.

        Args:
            adata: The data object to annotate.
            query_id: The column of `.obs` with the name of a perturbagen. Defaults to 'perturbation'.
            target: The column of `.obs` with target information. If set to None, all MoAs are retrieved without comparing molecular targets.
                    Defaults to None.
            verbosity: The number of unmatched identifiers to print, can be either non-negative values or 'all'.
                       Defaults to 5.
            copy: Determines whether a copy of the `adata` is returned. Defaults to False.

        Returns:
            Returns an AnnData object with MoA annotation.

        Raises:
            ValueError: If `query_id` is not a column of `adata.obs`.
        """
        if copy:
            adata = adata.copy()

        if query_id not in adata.obs.columns:
            raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\nPlease check again.")

        if self.clue is None:
            self._download_clue()

        identifier_num_all = len(adata.obs[query_id].unique())
        not_matched_identifiers = list(set(adata.obs[query_id].str.lower()) - set(self.clue["pert_iname"].str.lower()))
        self._warn_unmatch(
            total_identifiers=identifier_num_all,
            unmatched_identifiers=not_matched_identifiers,
            query_id=query_id,
            reference_id="pert_iname",
            metadata_type="moa",
            verbosity=verbosity,
        )

        # Merging on lower-cased key arrays (rather than column names) makes pandas add a
        # temporary `key_0` column; it is dropped again and the original index restored.
        adata.obs = (
            adata.obs.merge(
                self.clue,
                left_on=adata.obs[query_id].str.lower(),
                right_on=self.clue["pert_iname"].str.lower(),
                how="left",
                suffixes=("", "_fromMeta"),
            )
            .set_index(adata.obs.index)
            .drop("key_0", axis=1)
        )

        # If target column is given, check whether it is one of the targets listed in the metadata
        # If inconsistent, treat this perturbagen as unmatched and overwrite the annotated metadata with NaN
        if target is not None:
            # A name clash with an existing `.obs` column means the metadata column got the merge suffix.
            target_meta = "target" if target != "target" else "target_fromMeta"
            adata.obs[target_meta] = adata.obs[target_meta].mask(
                ~adata.obs.apply(lambda row: str(row[target]) in str(row[target_meta]), axis=1)
            )
            pertname_meta = "pert_iname" if query_id != "pert_iname" else "pert_iname_fromMeta"
            adata.obs.loc[adata.obs[target_meta].isna(), [pertname_meta, "moa"]] = np.nan

        # If query_id and reference_id have different names, there will be a column for each of them after merging
        # which is redundant as they refer to the same information.
        if query_id != "pert_iname":
            del adata.obs["pert_iname"]

        return adata

    def lookup(self) -> LookUp:
        """Generate LookUp object for Moa metadata.

        The LookUp object provides an overview of the metadata to annotate.
        The `annotate` method has a corresponding lookup function in the LookUp object,
        where users can search the query_ids and targets in the metadata.

        Returns:
            Returns a LookUp object specific for MoA annotation.
        """
        if self.clue is None:
            self._download_clue()

        return LookUp(
            type="moa",
            transfer_metadata=[self.clue],
        )
pertpy/plot/__init__.py CHANGED
@@ -1,13 +1,5 @@
1
1
  from pertpy.plot._augur import AugurpyPlot as ag
2
- from pertpy.plot._dialogue import DialoguePlot as dl
3
-
4
- try:
5
- from pertpy.plot._coda import CodaPlot as coda
6
- except ImportError:
7
- pass
8
-
9
- from pertpy.plot._cinemaot import CinemaotPlot as cot
2
+ from pertpy.plot._coda import CodaPlot as coda
10
3
  from pertpy.plot._guide_rna import GuideRnaPlot as guide
11
4
  from pertpy.plot._milopy import MilopyPlot as milo
12
5
  from pertpy.plot._mixscape import MixscapePlot as ms
13
- from pertpy.plot._scgen import JaxscgenPlot as scg
pertpy/plot/_augur.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import TYPE_CHECKING, Any
4
5
 
5
6
  from anndata import AnnData
@@ -15,7 +16,7 @@ class AugurpyPlot:
15
16
  """Plotting functions for Augurpy."""
16
17
 
17
18
  @staticmethod
18
- def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure: bool = False) -> Figure | Axes:
19
+ def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None) -> Figure | Axes:
19
20
  """Plot result of differential prioritization.
20
21
 
21
22
  Args:
@@ -42,38 +43,24 @@ class AugurpyPlot:
42
43
 
43
44
  >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
44
45
  permuted_results1=results_15_permute, permuted_results2=results_48_permute)
45
- >>> pt.pl.ag.dp_scatter(pvals)
46
+ >>> ag_rfc.plot_dp_scatter(pvals)
46
47
  """
47
- x = results["mean_augur_score1"]
48
- y = results["mean_augur_score2"]
49
-
50
- if ax is None:
51
- fig, ax = plt.subplots()
52
- scatter = ax.scatter(x, y, c=results.z, cmap="Greens")
53
-
54
- # adding optional labels
55
- top_n_index = results.sort_values(by="pval").index[:top_n]
56
- for idx in top_n_index:
57
- ax.annotate(
58
- results.loc[idx, "cell_type"],
59
- (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
60
- )
48
+ warnings.warn(
49
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
50
+ " Please use the corresponding 'pt.tl' object",
51
+ FutureWarning,
52
+ stacklevel=2,
53
+ )
61
54
 
62
- # add diagonal
63
- limits = max(ax.get_xlim(), ax.get_ylim())
64
- (diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
55
+ from pertpy.tools import Augur
65
56
 
66
- # formatting and details
67
- plt.xlabel("Augur scores 1")
68
- plt.ylabel("Augur scores 2")
69
- legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
70
- ax.add_artist(legend1)
57
+ ag = Augur("random_forest_classifier")
71
58
 
72
- return fig if return_figure else ax
59
+ return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax)
73
60
 
74
61
  @staticmethod
75
62
  def important_features(
76
- data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None, return_figure: bool = False
63
+ data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None
77
64
  ) -> Figure | Axes:
78
65
  """Plot a lollipop plot of the n features with largest feature importances.
79
66
 
@@ -92,44 +79,26 @@ class AugurpyPlot:
92
79
  >>> adata = pt.dt.sc_sim_augur()
93
80
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
94
81
  >>> loaded_data = ag_rfc.load(adata)
95
- >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
96
- >>> pt.pl.ag.important_features(v_results)
82
+ >>> v_adata, v_results = ag_rfc.predict(
83
+ ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
84
+ ... )
85
+ >>> ag_rfc.plot_important_features(v_results)
97
86
  """
98
- if isinstance(data, AnnData):
99
- results = data.uns[key]
100
- else:
101
- results = data
102
- # top_n features to plot
103
- n_features = (
104
- results["feature_importances"]
105
- .groupby("genes", as_index=False)
106
- .feature_importances.mean()
107
- .sort_values(by="feature_importances")[-top_n:]
87
+ warnings.warn(
88
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
89
+ " Please use the corresponding 'pt.tl' object",
90
+ FutureWarning,
91
+ stacklevel=2,
108
92
  )
109
93
 
110
- if ax is None:
111
- fig, ax = plt.subplots()
112
- y_axes_range = range(1, top_n + 1)
113
- ax.hlines(
114
- y_axes_range,
115
- xmin=0,
116
- xmax=n_features["feature_importances"],
117
- )
118
-
119
- # drawing the markers (circle)
120
- ax.plot(n_features["feature_importances"], y_axes_range, "o")
94
+ from pertpy.tools import Augur
121
95
 
122
- # formatting and details
123
- plt.xlabel("Mean Feature Importance")
124
- plt.ylabel("Gene")
125
- plt.yticks(y_axes_range, n_features["genes"])
96
+ ag = Augur("random_forest_classifier")
126
97
 
127
- return fig if return_figure else ax
98
+ return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax)
128
99
 
129
100
  @staticmethod
130
- def lollipop(
131
- data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False
132
- ) -> Figure | Axes:
101
+ def lollipop(data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None) -> Figure | Axes | None:
133
102
  """Plot a lollipop plot of the mean augur values.
134
103
 
135
104
  Args:
@@ -146,40 +115,26 @@ class AugurpyPlot:
146
115
  >>> adata = pt.dt.sc_sim_augur()
147
116
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
148
117
  >>> loaded_data = ag_rfc.load(adata)
149
- >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
150
- >>> pt.pl.ag.lollipop(v_results)
118
+ >>> v_adata, v_results = ag_rfc.predict(
119
+ ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
120
+ ... )
121
+ >>> ag_rfc.plot_lollipop(v_results)
151
122
  """
152
- if isinstance(data, AnnData):
153
- results = data.uns[key]
154
- else:
155
- results = data
156
- if ax is None:
157
- fig, ax = plt.subplots()
158
- y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
159
- ax.hlines(
160
- y_axes_range,
161
- xmin=0,
162
- xmax=results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
123
+ warnings.warn(
124
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
125
+ " Please use the corresponding 'pt.tl' object",
126
+ FutureWarning,
127
+ stacklevel=2,
163
128
  )
164
129
 
165
- # drawing the markers (circle)
166
- ax.plot(
167
- results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
168
- y_axes_range,
169
- "o",
170
- )
130
+ from pertpy.tools import Augur
171
131
 
172
- # formatting and details
173
- plt.xlabel("Mean Augur Score")
174
- plt.ylabel("Cell Type")
175
- plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
132
+ ag = Augur("random_forest_classifier")
176
133
 
177
- return fig if return_figure else ax
134
+ return ag.plot_lollipop(data=data, key=key, ax=ax)
178
135
 
179
136
  @staticmethod
180
- def scatterplot(
181
- results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False
182
- ) -> Figure | Axes:
137
+ def scatterplot(results1: dict[str, Any], results2: dict[str, Any], top_n=None) -> Figure | Axes:
183
138
  """Create scatterplot with two augur results.
184
139
 
185
140
  Args:
@@ -197,38 +152,20 @@ class AugurpyPlot:
197
152
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
198
153
  >>> loaded_data = ag_rfc.load(adata)
199
154
  >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
200
- >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
201
- >>> pt.pl.ag.scatterplot(v_results, h_results)
155
+ >>> v_adata, v_results = ag_rfc.predict(
156
+ ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
157
+ ... )
158
+ >>> ag_rfc.plot_scatterplot(v_results, h_results)
202
159
  """
203
- cell_types = results1["summary_metrics"].columns
204
-
205
- fig, ax = plt.subplots()
206
- ax.scatter(
207
- results1["summary_metrics"].loc["mean_augur_score", cell_types],
208
- results2["summary_metrics"].loc["mean_augur_score", cell_types],
160
+ warnings.warn(
161
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
162
+ " Please use the corresponding 'pt.tl' object",
163
+ FutureWarning,
164
+ stacklevel=2,
209
165
  )
210
166
 
211
- # adding optional labels
212
- top_n_cell_types = (
213
- (results1["summary_metrics"].loc["mean_augur_score"] - results2["summary_metrics"].loc["mean_augur_score"])
214
- .sort_values(ascending=False)
215
- .index[:top_n]
216
- )
217
- for txt in top_n_cell_types:
218
- ax.annotate(
219
- txt,
220
- (
221
- results1["summary_metrics"].loc["mean_augur_score", txt],
222
- results2["summary_metrics"].loc["mean_augur_score", txt],
223
- ),
224
- )
225
-
226
- # adding diagonal
227
- limits = max(ax.get_xlim(), ax.get_ylim())
228
- (diag_line,) = ax.plot(limits, limits, ls="--", c=".3")
229
-
230
- # formatting and details
231
- plt.xlabel("Augur scores 1")
232
- plt.ylabel("Augur scores 2")
233
-
234
- return fig if return_figure else ax
167
+ from pertpy.tools import Augur
168
+
169
+ ag = Augur("random_forest_classifier")
170
+
171
+ return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n)