pertpy-0.6.0-py3-none-any.whl → pertpy-0.7.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_guide_rna.py CHANGED
@@ -1,13 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import TYPE_CHECKING
4
5
 
5
- import numpy as np
6
- import scanpy as sc
7
-
8
6
  if TYPE_CHECKING:
7
+ import numpy as np
9
8
  from anndata import AnnData
10
- from matplotlib.axes import Axes
11
9
 
12
10
 
13
11
  class GuideRnaPlot:
@@ -17,8 +15,8 @@ class GuideRnaPlot:
17
15
  layer: str | None = None,
18
16
  order_by: np.ndarray | str | None = None,
19
17
  key_to_save_order: str = None,
20
- **kwds,
21
- ) -> list[Axes]:
18
+ **kwargs,
19
+ ):
22
20
  """Heatmap plotting of guide RNA expression matrix.
23
21
 
24
22
  Assuming guides have sparse expression, this function reorders cells
@@ -36,7 +34,7 @@ class GuideRnaPlot:
36
34
  If a string is provided, adata.obs[order_by] will be used as the order.
37
35
  If a numpy array is provided, the array will be used for ordering.
38
36
  key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
39
- kwds: Are passed to sc.pl.heatmap.
37
+ kwargs: Are passed to sc.pl.heatmap.
40
38
 
41
39
  Returns:
42
40
  List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
@@ -47,36 +45,20 @@ class GuideRnaPlot:
47
45
  visualized using a heatmap.
48
46
 
49
47
  >>> import pertpy as pt
50
- >>> mdata = pt.data.papalexi_2021()
51
- >>> gdo = mdata.mod['gdo']
48
+ >>> mdata = pt.dt.papalexi_2021()
49
+ >>> gdo = mdata.mod["gdo"]
52
50
  >>> ga = pt.pp.GuideAssignment()
53
51
  >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
54
- >>> pt.pl.guide.heatmap(gdo)
52
+ >>> ga.plot_heatmap(gdo)
55
53
  """
56
- data = adata.X if layer is None else adata.layers[layer]
54
+ warnings.warn(
55
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
56
+ " Please use the corresponding 'pt.tl' object",
57
+ FutureWarning,
58
+ stacklevel=2,
59
+ )
57
60
 
58
- if order_by is None:
59
- max_guide_index = np.where(
60
- np.array(data.max(axis=1)).squeeze() != data.min(), np.array(data.argmax(axis=1)).squeeze(), -1
61
- )
62
- order = np.argsort(max_guide_index)
63
- elif isinstance(order_by, str):
64
- order = adata.obs[order_by]
65
- else:
66
- order = order_by
61
+ from pertpy.preprocessing import GuideAssignment
67
62
 
68
- adata.obs["_tmp_pertpy_grna_plot_dummy_group"] = ""
69
- if key_to_save_order is not None:
70
- adata.obs[key_to_save_order] = order
71
- axis_group = sc.pl.heatmap(
72
- adata[order],
73
- adata.var.index.tolist(),
74
- groupby="_tmp_pertpy_grna_plot_dummy_group",
75
- cmap="viridis",
76
- use_raw=False,
77
- dendrogram=False,
78
- layer=layer,
79
- **kwds,
80
- )
81
- del adata.obs["_tmp_pertpy_grna_plot_dummy_group"]
82
- return axis_group
63
+ ga = GuideAssignment()
64
+ ga.plot_heatmap(adata=adata, layer=layer, order_by=order_by, key_to_save_order=key_to_save_order, kwargs=kwargs)
pertpy/plot/_milopy.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import TYPE_CHECKING
4
5
 
5
6
  import matplotlib.pyplot as plt
@@ -52,44 +53,29 @@ class MilopyPlot:
52
53
  >>> sc.tl.umap(mdata["rna"])
53
54
  >>> milo.make_nhoods(mdata["rna"])
54
55
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
55
- >>> milo.da_nhoods(mdata, design="~label")
56
+ >>> milo.da_nhoods(mdata,
57
+ ... design='~label',
58
+ ... model_contrasts='labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine')
56
59
  >>> milo.build_nhood_graph(mdata)
57
- >>> pt.pl.milo.nhood_graph(mdata)
58
- # TODO: If necessary adjust after fixing StopIteration error, which is currently thrown
60
+ >>> milo.plot_nhood_graph(mdata)
59
61
  """
60
- nhood_adata = mdata["milo"].T.copy()
61
-
62
- if "Nhood_size" not in nhood_adata.obs.columns:
63
- raise KeyError(
64
- 'Cannot find "Nhood_size" column in adata.uns["nhood_adata"].obs -- \
65
- please run milopy.utils.build_nhood_graph(adata)'
66
- )
67
-
68
- nhood_adata.obs["graph_color"] = nhood_adata.obs["logFC"]
69
- nhood_adata.obs.loc[nhood_adata.obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
70
- nhood_adata.obs["abs_logFC"] = abs(nhood_adata.obs["logFC"])
71
- nhood_adata.obs.loc[nhood_adata.obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan
62
+ warnings.warn(
63
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
64
+ " Please use the corresponding 'pt.tl' object",
65
+ FutureWarning,
66
+ stacklevel=2,
67
+ )
72
68
 
73
- # Plotting order - extreme logFC on top
74
- nhood_adata.obs.loc[nhood_adata.obs["graph_color"].isna(), "abs_logFC"] = np.nan
75
- ordered = nhood_adata.obs.sort_values("abs_logFC", na_position="first").index
76
- nhood_adata = nhood_adata[ordered]
69
+ from pertpy.tools import Milo
77
70
 
78
- vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
79
- vmin = -vmax
71
+ milo = Milo()
80
72
 
81
- sc.pl.embedding(
82
- nhood_adata,
83
- "X_milo_graph",
84
- color="graph_color",
85
- cmap="RdBu_r",
86
- size=nhood_adata.obs["Nhood_size"] * min_size,
87
- edges=plot_edges,
88
- neighbors_key="nhood",
89
- sort_order=False,
90
- frameon=False,
91
- vmax=vmax,
92
- vmin=vmin,
73
+ return milo.plot_nhood_graph(
74
+ madata=mdata,
75
+ alpha=alpha,
76
+ min_logFC=min_logFC,
77
+ min_size=min_size,
78
+ plot_edges=plot_edges,
93
79
  title=title,
94
80
  show=show,
95
81
  save=save,
@@ -127,12 +113,19 @@ class MilopyPlot:
127
113
  >>> milo.make_nhoods(mdata["rna"])
128
114
  >>> pt.pl.milo.nhood(mdata, ix=0)
129
115
  """
130
-
131
- mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
132
- sc.pl.embedding(
133
- mdata[feature_key], basis, color="Nhood", size=30, title="Nhood" + str(ix), show=show, save=save, **kwargs
116
+ warnings.warn(
117
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
118
+ " Please use the corresponding 'pt.tl' object",
119
+ FutureWarning,
120
+ stacklevel=2,
134
121
  )
135
122
 
123
+ from pertpy.tools import Milo
124
+
125
+ milo = Milo()
126
+
127
+ milo.plot_nhood(mdata=mdata, ix=ix, feature_key=feature_key, basis=basis, show=show, save=save, **kwargs)
128
+
136
129
  @staticmethod
137
130
  def da_beeswarm(
138
131
  mdata: MuData,
@@ -146,7 +139,7 @@ class MilopyPlot:
146
139
 
147
140
  Args:
148
141
  mdata: MuData object
149
- anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
142
+ anno_col: Column in mdata['milo'].var to use as annotation. (default: 'nhood_annotation'.)
150
143
  alpha: Significance threshold. (default: 0.1)
151
144
  subset_nhoods: List of nhoods to plot. If None, plot all nhoods. (default: None)
152
145
  palette: Name of Seaborn color palette for violinplots.
@@ -162,84 +155,28 @@ class MilopyPlot:
162
155
  >>> milo.make_nhoods(mdata["rna"])
163
156
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
164
157
  >>> milo.da_nhoods(mdata, design="~label")
165
- >>> milo.annotate_nhoods(mdata, anno_col='cell_type')
166
- >>> pt.pl.milo.da_beeswarm(mdata)
158
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
159
+ >>> milo.plot_da_beeswarm(mdata)
167
160
  """
168
- try:
169
- nhood_adata = mdata["milo"].T.copy()
170
- except KeyError:
171
- raise RuntimeError(
172
- "mdata should be a MuData object with two slots: feature_key and 'milo'. Run 'milopy.count_nhoods(adata)' first."
173
- ) from None
174
-
175
- if subset_nhoods is not None:
176
- nhood_adata = nhood_adata[subset_nhoods]
177
-
178
- try:
179
- nhood_adata.obs[anno_col]
180
- except KeyError:
181
- raise RuntimeError(
182
- f"Unable to find {anno_col} in mdata.uns['nhood_adata']. Run 'milopy.utils.annotate_nhoods(adata, anno_col)' first"
183
- ) from None
184
-
185
- try:
186
- nhood_adata.obs["logFC"]
187
- except KeyError:
188
- raise RuntimeError(
189
- "Unable to find 'logFC' in mdata.uns['nhood_adata'].obs. Run 'core.da_nhoods(adata)' first."
190
- ) from None
191
-
192
- sorted_annos = (
193
- nhood_adata.obs[[anno_col, "logFC"]].groupby(anno_col).median().sort_values("logFC", ascending=True).index
161
+ warnings.warn(
162
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
163
+ " Please use the corresponding 'pt.tl' object",
164
+ FutureWarning,
165
+ stacklevel=2,
194
166
  )
195
167
 
196
- anno_df = nhood_adata.obs[[anno_col, "logFC", "SpatialFDR"]].copy()
197
- anno_df["is_signif"] = anno_df["SpatialFDR"] < alpha
198
- anno_df = anno_df[anno_df[anno_col] != "nan"]
168
+ from pertpy.tools import Milo
199
169
 
200
- try:
201
- obs_col = nhood_adata.uns["annotation_obs"]
202
- if palette is None:
203
- palette = dict(
204
- zip(mdata[feature_key].obs[obs_col].cat.categories, mdata[feature_key].uns[f"{obs_col}_colors"])
205
- )
206
- sns.violinplot(
207
- data=anno_df,
208
- y=anno_col,
209
- x="logFC",
210
- order=sorted_annos,
211
- size=190,
212
- inner=None,
213
- orient="h",
214
- palette=palette,
215
- linewidth=0,
216
- scale="width",
217
- )
218
- except BaseException: # noqa: BLE001
219
- sns.violinplot(
220
- data=anno_df,
221
- y=anno_col,
222
- x="logFC",
223
- order=sorted_annos,
224
- size=190,
225
- inner=None,
226
- orient="h",
227
- linewidth=0,
228
- scale="width",
229
- )
230
- sns.stripplot(
231
- data=anno_df,
232
- y=anno_col,
233
- x="logFC",
234
- order=sorted_annos,
235
- size=2,
236
- hue="is_signif",
237
- palette=["grey", "black"],
238
- orient="h",
239
- alpha=0.5,
170
+ milo = Milo()
171
+
172
+ milo.plot_da_beeswarm(
173
+ mdata=mdata,
174
+ feature_key=feature_key,
175
+ anno_col=anno_col,
176
+ alpha=alpha,
177
+ subset_nhoods=subset_nhoods,
178
+ palette=palette,
240
179
  )
241
- plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
242
- plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
243
180
 
244
181
  @staticmethod
245
182
  def nhood_counts_by_cond(
@@ -256,29 +193,17 @@ class MilopyPlot:
256
193
  subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods. (default: None)
257
194
  log_counts: Whether to plot log1p of cell counts. (default: False)
258
195
  """
259
- try:
260
- nhood_adata = mdata["milo"].T.copy()
261
- except KeyError:
262
- raise RuntimeError(
263
- "mdata should be a MuData object with two slots: feature_key and 'milo'. Run milopy.count_nhoods(mdata) first"
264
- ) from None
196
+ warnings.warn(
197
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
198
+ " Please use the corresponding 'pt.tl' object",
199
+ FutureWarning,
200
+ stacklevel=2,
201
+ )
265
202
 
266
- if subset_nhoods is None:
267
- subset_nhoods = nhood_adata.obs_names
203
+ from pertpy.tools import Milo
268
204
 
269
- pl_df = pd.DataFrame(nhood_adata[subset_nhoods].X.A, columns=nhood_adata.var_names).melt(
270
- var_name=nhood_adata.uns["sample_col"], value_name="n_cells"
271
- )
272
- pl_df = pd.merge(pl_df, nhood_adata.var)
273
- pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
274
- if not log_counts:
275
- sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue")
276
- sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3)
277
- plt.ylabel("# cells")
278
- else:
279
- sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue")
280
- sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3)
281
- plt.ylabel("log(# cells + 1)")
205
+ milo = Milo()
282
206
 
283
- plt.xticks(rotation=90)
284
- plt.xlabel(test_var)
207
+ milo.plot_nhood_counts_by_cond(
208
+ mdata=mdata, test_var=test_var, subset_nhoods=subset_nhoods, log_counts=log_counts
209
+ )