pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_mixscape.py CHANGED
@@ -1,35 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- import copy
4
- from collections import OrderedDict
3
+ import warnings
5
4
  from typing import TYPE_CHECKING, Literal
6
5
 
7
- import numpy as np
8
- import pandas as pd
9
- import scanpy as sc
10
- from matplotlib import pyplot as pl
11
- from plotnine import (
12
- aes,
13
- element_blank,
14
- element_text,
15
- facet_wrap,
16
- geom_bar,
17
- geom_density,
18
- geom_point,
19
- ggplot,
20
- labs,
21
- scale_color_manual,
22
- scale_fill_manual,
23
- theme,
24
- theme_classic,
25
- xlab,
26
- ylab,
27
- )
28
- from scanpy import get
29
- from scanpy._settings import settings
30
- from scanpy._utils import _check_use_raw, sanitize_anndata
31
- from scanpy.plotting import _utils
32
-
33
6
  if TYPE_CHECKING:
34
7
  from collections.abc import Sequence
35
8
 
@@ -51,8 +24,8 @@ class MixscapePlot:
51
24
  strip_text_size: int = 6,
52
25
  panel_spacing_x: float = 0.3,
53
26
  panel_spacing_y: float = 0.3,
54
- legend_title_size: int = 8,
55
- legend_text_size: int = 8,
27
+ legend_title_size: int = 18,
28
+ legend_text_size: int = 18,
56
29
  show: bool | None = None,
57
30
  save: bool | str | None = None,
58
31
  ):
@@ -73,58 +46,34 @@ class MixscapePlot:
73
46
  Examples:
74
47
  >>> import pertpy as pt
75
48
  >>> mdata = pt.dt.papalexi_2021()
76
- >>> mixscape_identifier = pt.tl.Mixscape()
77
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
78
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
79
- >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
49
+ >>> ms = pt.tl.Mixscape()
50
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
51
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
52
+ >>> ms.plot_barplot(mdata["rna"], guide_rna_column="NT")
80
53
  """
81
- if mixscape_class_global not in adata.obs:
82
- raise ValueError("Please run `pt.tl.mixscape` first.")
83
- count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
84
- all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
85
- KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
86
- KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)
87
-
88
- new_levels = KO_cells_percentage[guide_rna_column]
89
- all_cells_percentage[guide_rna_column] = pd.Categorical(
90
- all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
91
- )
92
- all_cells_percentage[mixscape_class_global] = pd.Categorical(
93
- all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
94
- )
95
- all_cells_percentage["gene"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[0]
96
- all_cells_percentage["guide_number"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[1]
97
- all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
98
- NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
99
-
100
- p1 = (
101
- ggplot(NP_KO_cells, aes(x="guide_number", y="value", fill="mixscape_class_global"))
102
- + scale_fill_manual(values=["#7d7d7d", "#c9c9c9", "#ff7256"])
103
- + geom_bar(stat="identity")
104
- + theme_classic()
105
- + xlab("sgRNA")
106
- + ylab("% of cells")
54
+ warnings.warn(
55
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
56
+ " Please use the corresponding 'pt.tl' object",
57
+ FutureWarning,
58
+ stacklevel=2,
107
59
  )
108
60
 
109
- p1 = (
110
- p1
111
- + theme(
112
- axis_text_x=element_text(size=axis_text_x_size, hjust=2),
113
- axis_text_y=element_text(size=axis_text_y_size),
114
- axis_title=element_text(size=axis_title_size),
115
- strip_text=element_text(size=strip_text_size, face="bold"),
116
- panel_spacing_x=panel_spacing_x,
117
- panel_spacing_y=panel_spacing_y,
118
- )
119
- + facet_wrap("gene", ncol=5, scales="free")
120
- + labs(fill="mixscape class")
121
- + theme(legend_title=element_text(size=legend_title_size), legend_text=element_text(size=legend_text_size))
61
+ from pertpy.tools import Mixscape
62
+
63
+ ms = Mixscape()
64
+ return ms.plot_barplot(
65
+ adata=adata,
66
+ guide_rna_column=guide_rna_column,
67
+ mixscape_class_global=mixscape_class_global,
68
+ axis_text_x_size=axis_text_x_size,
69
+ axis_text_y_size=axis_text_y_size,
70
+ axis_title_size=axis_title_size,
71
+ legend_title_size=legend_title_size,
72
+ legend_text_size=legend_text_size,
73
+ show=show,
74
+ save=save,
122
75
  )
123
76
 
124
- _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
125
- if not show:
126
- return p1
127
-
128
77
  @staticmethod
129
78
  def heatmap( # pragma: no cover
130
79
  adata: AnnData,
@@ -138,7 +87,7 @@ class MixscapePlot:
138
87
  vmax: float | None = 2,
139
88
  show: bool | None = None,
140
89
  save: bool | str | None = None,
141
- **kwds,
90
+ **kwargs,
142
91
  ):
143
92
  """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
144
93
 
@@ -154,33 +103,41 @@ class MixscapePlot:
154
103
  vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
155
104
  show: Show the plot, do not return axis.
156
105
  save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
157
- ax: A matplotlib axes object. Only works if plotting a single component.
158
106
  **kwargs: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
159
107
 
160
108
  Examples:
161
109
  >>> import pertpy as pt
162
110
  >>> mdata = pt.dt.papalexi_2021()
163
- >>> mixscape_identifier = pt.tl.Mixscape()
164
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
165
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
166
- >>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
111
+ >>> ms = pt.tl.Mixscape()
112
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
113
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
114
+ >>> ms.plot_heatmap(
115
+ ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
116
+ ... )
167
117
  """
168
- if "mixscape_class" not in adata.obs:
169
- raise ValueError("Please run `pt.tl.mixscape` first.")
170
- adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
171
- sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
172
- sc.pp.scale(adata_subset, max_value=vmax)
173
- sc.pp.subsample(adata_subset, n_obs=subsample_number)
174
- return sc.pl.rank_genes_groups_heatmap(
175
- adata_subset,
176
- groupby="mixscape_class",
118
+ warnings.warn(
119
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
120
+ " Please use the corresponding 'pt.tl' object",
121
+ FutureWarning,
122
+ stacklevel=2,
123
+ )
124
+
125
+ from pertpy.tools import Mixscape
126
+
127
+ ms = Mixscape()
128
+ return ms.plot_heatmap(
129
+ adata=adata,
130
+ labels=labels,
131
+ target_gene=target_gene,
132
+ control=control,
133
+ layer=layer,
134
+ method=method,
135
+ subsample_number=subsample_number,
177
136
  vmin=vmin,
178
137
  vmax=vmax,
179
- n_genes=20,
180
- groups=["NT"],
181
138
  show=show,
182
139
  save=save,
183
- **kwds,
140
+ **kwargs,
184
141
  )
185
142
 
186
143
  @staticmethod
@@ -218,132 +175,32 @@ class MixscapePlot:
218
175
  >>> import pertpy as pt
219
176
  >>> mdata = pt.dt.papalexi_2021()
220
177
  >>> mixscape_identifier = pt.tl.Mixscape()
221
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
222
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
223
- >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
178
+ >>> mixscape_identifier.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
179
+ >>> mixscape_identifier.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
180
+ >>> mixscape_identifier.plot_perturbscore(
181
+ ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange"
182
+ ... )
224
183
  """
225
- if "mixscape" not in adata.uns:
226
- raise ValueError("Please run `pt.tl.mixscape` first.")
227
- perturbation_score = None
228
- for key in adata.uns["mixscape"][target_gene].keys():
229
- perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
230
- perturbation_score_temp["name"] = key
231
- if perturbation_score is None:
232
- perturbation_score = copy.deepcopy(perturbation_score_temp)
233
- else:
234
- perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
235
- perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
236
- gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
237
- # If before_mixscape is True, split densities based on original target gene classification
238
- if before_mixscape is True:
239
- cols = {gd: "#7d7d7d", target_gene: color}
240
- p = ggplot(perturbation_score, aes(x="pvec", color=labels)) + geom_density() + theme_classic()
241
- p_copy = copy.deepcopy(p)
242
- p_copy._build()
243
- top_r = max(p_copy.layers[0].data["density"])
244
- perturbation_score["y_jitter"] = perturbation_score["pvec"]
245
- rng = np.random.default_rng()
246
- perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
247
- low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
248
- )
249
- perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
250
- low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
251
- )
252
- # If split_by is provided, split densities based on the split_by
253
- if split_by is not None:
254
- perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
255
- p2 = (
256
- p
257
- + scale_color_manual(values=cols, drop=False)
258
- + geom_density(size=1.5)
259
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
260
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
261
- + ylab("Cell density")
262
- + xlab("Perturbation score")
263
- + theme(
264
- legend_key_size=1,
265
- legend_text=element_text(colour="black", size=14),
266
- legend_title=element_blank(),
267
- plot_title=element_text(size=16, face="bold"),
268
- )
269
- + facet_wrap("split")
270
- )
271
- else:
272
- p2 = (
273
- p
274
- + scale_color_manual(values=cols, drop=False)
275
- + geom_density(size=1.5)
276
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
277
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
278
- + ylab("Cell density")
279
- + xlab("Perturbation score")
280
- + theme(
281
- legend_key_size=1,
282
- legend_text=element_text(colour="black", size=14),
283
- legend_title=element_blank(),
284
- plot_title=element_text(size=16, face="bold"),
285
- )
286
- )
287
- # If before_mixscape is False, split densities based on mixscape classifications
288
- else:
289
- cols = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
290
- p = ggplot(perturbation_score, aes(x="pvec", color="mix")) + geom_density() + theme_classic()
291
- p_copy = copy.deepcopy(p)
292
- p_copy._build()
293
- top_r = max(p_copy.layers[0].data["density"])
294
- perturbation_score["y_jitter"] = perturbation_score["pvec"]
295
- rng = np.random.default_rng()
296
- gd2 = list(
297
- set(perturbation_score["mix"]).difference([f"{target_gene} NP", f"{target_gene} {perturbation_type}"])
298
- )[0]
299
- perturbation_score.loc[perturbation_score["mix"] == gd2, "y_jitter"] = rng.uniform(
300
- low=0.001, high=top_r / 10, size=sum(perturbation_score["mix"] == gd2)
301
- )
302
- perturbation_score.loc[
303
- perturbation_score["mix"] == f"{target_gene} {perturbation_type}", "y_jitter"
304
- ] = rng.uniform(
305
- low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} {perturbation_type}")
306
- )
307
- perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} NP", "y_jitter"] = rng.uniform(
308
- low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} NP")
309
- )
310
- # If split_by is provided, split densities based on the split_by
311
- if split_by is not None:
312
- perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
313
- p2 = (
314
- ggplot(perturbation_score, aes(x="pvec", color="mix"))
315
- + scale_color_manual(values=cols, drop=False)
316
- + geom_density(size=1.5)
317
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
318
- + theme_classic()
319
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
320
- + ylab("Cell density")
321
- + xlab("Perturbation score")
322
- + theme(
323
- legend_key_size=1,
324
- legend_text=element_text(colour="black", size=14),
325
- legend_title=element_blank(),
326
- plot_title=element_text(size=16, face="bold"),
327
- )
328
- + facet_wrap("split")
329
- )
330
- else:
331
- p2 = (
332
- p
333
- + scale_color_manual(values=cols, drop=False)
334
- + geom_density(size=1.5)
335
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
336
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
337
- + ylab("Cell density")
338
- + xlab("Perturbation score")
339
- + theme(
340
- legend_key_size=1,
341
- legend_text=element_text(colour="black", size=14),
342
- legend_title=element_blank(),
343
- plot_title=element_text(size=16, face="bold"),
344
- )
345
- )
346
- return p2
184
+ warnings.warn(
185
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
186
+ " Please use the corresponding 'pt.tl' object",
187
+ FutureWarning,
188
+ stacklevel=2,
189
+ )
190
+
191
+ from pertpy.tools import Mixscape
192
+
193
+ ms = Mixscape()
194
+ return ms.plot_perturbscore(
195
+ adata=adata,
196
+ labels=labels,
197
+ target_gene=target_gene,
198
+ mixscape_class=mixscape_class,
199
+ color=color,
200
+ split_by=split_by,
201
+ before_mixscape=before_mixscape,
202
+ perturbation_type=perturbation_type,
203
+ )
347
204
 
348
205
  @staticmethod
349
206
  def violin( # pragma: no cover
@@ -367,7 +224,7 @@ class MixscapePlot:
367
224
  show: bool | None = None,
368
225
  save: bool | str | None = None,
369
226
  ax: Axes | None = None,
370
- **kwds,
227
+ **kwargs,
371
228
  ):
372
229
  """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.
373
230
 
@@ -385,7 +242,7 @@ class MixscapePlot:
385
242
  show: Show the plot, do not return axis.
386
243
  save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
387
244
  ax: A matplotlib axes object. Only works if plotting a single component.
388
- **kwds: Additional arguments to `seaborn.violinplot`.
245
+ **kwargs: Additional arguments to `seaborn.violinplot`.
389
246
 
390
247
  Returns:
391
248
  A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
@@ -393,153 +250,46 @@ class MixscapePlot:
393
250
  Examples:
394
251
  >>> import pertpy as pt
395
252
  >>> mdata = pt.dt.papalexi_2021()
396
- >>> mixscape_identifier = pt.tl.Mixscape()
397
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
398
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
399
- >>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
253
+ >>> ms = pt.tl.Mixscape()
254
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
255
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
256
+ >>> ms.plot_violin(
257
+ ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
258
+ ... )
400
259
  """
401
- if isinstance(target_gene_idents, str):
402
- mixscape_class_mask = adata.obs[groupby] == target_gene_idents
403
- elif isinstance(target_gene_idents, list):
404
- mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
405
- for ident in target_gene_idents:
406
- mixscape_class_mask |= adata.obs[groupby] == ident
407
- adata = adata[mixscape_class_mask]
408
-
409
- import seaborn as sns # Slow import, only import if called
410
-
411
- sanitize_anndata(adata)
412
- use_raw = _check_use_raw(adata, use_raw)
413
- if isinstance(keys, str):
414
- keys = [keys]
415
- keys = list(OrderedDict.fromkeys(keys)) # remove duplicates, preserving the order
416
-
417
- if isinstance(ylabel, (str, type(None))):
418
- ylabel = [ylabel] * (1 if groupby is None else len(keys))
419
- if groupby is None:
420
- if len(ylabel) != 1:
421
- raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
422
- elif len(ylabel) != len(keys):
423
- raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
424
-
425
- if groupby is not None:
426
- if hue is not None:
427
- obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
428
- else:
429
- obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
430
-
431
- else:
432
- obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
433
- if groupby is None:
434
- obs_tidy = pd.melt(obs_df, value_vars=keys)
435
- x = "variable"
436
- ys = ["value"]
437
- else:
438
- obs_tidy = obs_df
439
- x = groupby
440
- ys = keys
441
-
442
- if multi_panel and groupby is None and len(ys) == 1:
443
- # This is a quick and dirty way for adapting scales across several
444
- # keys if groupby is None.
445
- y = ys[0]
446
-
447
- g = sns.catplot(
448
- y=y,
449
- data=obs_tidy,
450
- kind="violin",
451
- scale=scale,
452
- col=x,
453
- col_order=keys,
454
- sharey=False,
455
- order=keys,
456
- cut=0,
457
- inner=None,
458
- **kwds,
459
- )
460
-
461
- if stripplot:
462
- grouped_df = obs_tidy.groupby(x)
463
- for ax_id, key in zip(range(g.axes.shape[1]), keys):
464
- sns.stripplot(
465
- y=y,
466
- data=grouped_df.get_group(key),
467
- jitter=jitter,
468
- size=size,
469
- color="black",
470
- ax=g.axes[0, ax_id],
471
- )
472
- if log:
473
- g.set(yscale="log")
474
- g.set_titles(col_template="{col_name}").set_xlabels("")
475
- if rotation is not None:
476
- for ax in g.axes[0]:
477
- ax.tick_params(axis="x", labelrotation=rotation)
478
- else:
479
- # set by default the violin plot cut=0 to limit the extend
480
- # of the violin plot (see stacked_violin code) for more info.
481
- kwds.setdefault("cut", 0)
482
- kwds.setdefault("inner")
483
-
484
- if ax is None:
485
- axs, _, _, _ = _utils.setup_axes(
486
- ax=ax,
487
- panels=["x"] if groupby is None else keys,
488
- show_ticks=True,
489
- right_margin=0.3,
490
- )
491
- else:
492
- axs = [ax]
493
- for ax, y, ylab in zip(axs, ys, ylabel): # noqa: F402
494
- ax = sns.violinplot(
495
- x=x,
496
- y=y,
497
- data=obs_tidy,
498
- order=order,
499
- orient="vertical",
500
- scale=scale,
501
- ax=ax,
502
- hue=hue,
503
- **kwds,
504
- )
505
- # Get the handles and labels.
506
- handles, labels = ax.get_legend_handles_labels()
507
- if stripplot:
508
- ax = sns.stripplot(
509
- x=x,
510
- y=y,
511
- data=obs_tidy,
512
- order=order,
513
- jitter=jitter,
514
- color="black",
515
- size=size,
516
- ax=ax,
517
- hue=hue,
518
- dodge=True,
519
- )
520
- if xlabel == "" and groupby is not None and rotation is None:
521
- xlabel = groupby.replace("_", " ")
522
- ax.set_xlabel(xlabel)
523
- if ylab is not None:
524
- ax.set_ylabel(ylab)
525
-
526
- if log:
527
- ax.set_yscale("log")
528
- if rotation is not None:
529
- ax.tick_params(axis="x", labelrotation=rotation)
530
-
531
- show = settings.autoshow if show is None else show
532
- if hue is not None and stripplot is True:
533
- pl.legend(handles, labels)
534
- _utils.savefig_or_show("mixscape_violin", show=show, save=save)
535
-
536
- if not show:
537
- if multi_panel and groupby is None and len(ys) == 1:
538
- return g
539
- elif len(axs) == 1:
540
- return axs[0]
541
- else:
542
- return axs
260
+ warnings.warn(
261
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
262
+ " Please use the corresponding 'pt.tl' object",
263
+ FutureWarning,
264
+ stacklevel=2,
265
+ )
266
+
267
+ from pertpy.tools import Mixscape
268
+
269
+ ms = Mixscape()
270
+ return ms.plot_violin(
271
+ adata=adata,
272
+ target_gene_idents=target_gene_idents,
273
+ keys=keys,
274
+ groupby=groupby,
275
+ log=log,
276
+ use_raw=use_raw,
277
+ stripplot=stripplot,
278
+ hue=hue,
279
+ jitter=jitter,
280
+ size=size,
281
+ layer=layer,
282
+ scale=scale,
283
+ order=order,
284
+ multi_panel=multi_panel,
285
+ xlabel=xlabel,
286
+ ylabel=ylabel,
287
+ rotation=rotation,
288
+ show=show,
289
+ save=save,
290
+ ax=ax,
291
+ **kwargs,
292
+ )
543
293
 
544
294
  @staticmethod
545
295
  def lda( # pragma: no cover
@@ -552,7 +302,7 @@ class MixscapePlot:
552
302
  n_components: int | None = None,
553
303
  show: bool | None = None,
554
304
  save: bool | str | None = None,
555
- **kwds,
305
+ **kwargs,
556
306
  ):
557
307
  """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
558
308
 
@@ -562,33 +312,44 @@ class MixscapePlot:
562
312
  labels: The column of `.obs` with target gene labels.
563
313
  mixscape_class: The column of `.obs` with the mixscape classification result.
564
314
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
565
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to 'KO'.
315
+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
316
+ Defaults to 'KO'.
566
317
  lda_key: If not specified, lda looks in .uns["mixscape_lda"] for the LDA results.
567
318
  n_components: The number of dimensions of the embedding.
568
319
  show: Show the plot, do not return axis.
569
- save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
570
- **kwds: Additional arguments to `scanpy.pl.umap`.
320
+ save: If `True` or a `str`, save the figure. A string is appended to the default filename.
321
+ Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
322
+ **kwargs: Additional arguments to `scanpy.pl.umap`.
571
323
 
572
324
  Examples:
573
325
  >>> import pertpy as pt
574
326
  >>> mdata = pt.dt.papalexi_2021()
575
- >>> mixscape_identifier = pt.tl.Mixscape()
576
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
577
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
578
- >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
579
- >>> pt.pl.ms.lda(adata=mdata['rna'], control='NT')
327
+ >>> ms = pt.tl.Mixscape()
328
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
329
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
330
+ >>> ms.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
331
+ >>> ms.plot_lda(adata=mdata["rna"], control="NT")
580
332
  """
581
- if mixscape_class not in adata.obs:
582
- raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
583
- if lda_key not in adata.uns:
584
- raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.')
585
-
586
- adata_subset = adata[
587
- (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
588
- ].copy()
589
- adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
590
- if n_components is None:
591
- n_components = adata_subset.uns[lda_key].shape[1]
592
- sc.pp.neighbors(adata_subset, use_rep=lda_key)
593
- sc.tl.umap(adata_subset, n_components=n_components)
594
- sc.pl.umap(adata_subset, color=mixscape_class, show=show, save=save, **kwds)
333
+ warnings.warn(
334
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
335
+ " Please use the corresponding 'pt.tl' object",
336
+ FutureWarning,
337
+ stacklevel=2,
338
+ )
339
+
340
+ from pertpy.tools import Mixscape
341
+
342
+ ms = Mixscape()
343
+
344
+ return ms.plot_lda(
345
+ adata=adata,
346
+ control=control,
347
+ mixscape_class=mixscape_class,
348
+ mixscape_class_global=mixscape_class_global,
349
+ perturbation_type=perturbation_type,
350
+ lda_key=lda_key,
351
+ n_components=n_components,
352
+ show=show,
353
+ save=save,
354
+ **kwargs,
355
+ )