pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_mixscape.py CHANGED
@@ -1,35 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- import copy
4
- from collections import OrderedDict
3
+ import warnings
5
4
  from typing import TYPE_CHECKING, Literal
6
5
 
7
- import numpy as np
8
- import pandas as pd
9
- import scanpy as sc
10
- from matplotlib import pyplot as pl
11
- from plotnine import (
12
- aes,
13
- element_blank,
14
- element_text,
15
- facet_wrap,
16
- geom_bar,
17
- geom_density,
18
- geom_point,
19
- ggplot,
20
- labs,
21
- scale_color_manual,
22
- scale_fill_manual,
23
- theme,
24
- theme_classic,
25
- xlab,
26
- ylab,
27
- )
28
- from scanpy import get
29
- from scanpy._settings import settings
30
- from scanpy._utils import _check_use_raw, sanitize_anndata
31
- from scanpy.plotting import _utils
32
-
33
6
  if TYPE_CHECKING:
34
7
  from collections.abc import Sequence
35
8
 
@@ -51,8 +24,8 @@ class MixscapePlot:
51
24
  strip_text_size: int = 6,
52
25
  panel_spacing_x: float = 0.3,
53
26
  panel_spacing_y: float = 0.3,
54
- legend_title_size: int = 8,
55
- legend_text_size: int = 8,
27
+ legend_title_size: int = 18,
28
+ legend_text_size: int = 18,
56
29
  show: bool | None = None,
57
30
  save: bool | str | None = None,
58
31
  ):
@@ -73,58 +46,34 @@ class MixscapePlot:
73
46
  Examples:
74
47
  >>> import pertpy as pt
75
48
  >>> mdata = pt.dt.papalexi_2021()
76
- >>> mixscape_identifier = pt.tl.Mixscape()
77
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
78
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
79
- >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
49
+ >>> ms = pt.tl.Mixscape()
50
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
51
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
52
+ >>> ms.plot_barplot(mdata["rna"], guide_rna_column="NT")
80
53
  """
81
- if mixscape_class_global not in adata.obs:
82
- raise ValueError("Please run `pt.tl.mixscape` first.")
83
- count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
84
- all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
85
- KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
86
- KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)
87
-
88
- new_levels = KO_cells_percentage[guide_rna_column]
89
- all_cells_percentage[guide_rna_column] = pd.Categorical(
90
- all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
91
- )
92
- all_cells_percentage[mixscape_class_global] = pd.Categorical(
93
- all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
94
- )
95
- all_cells_percentage["gene"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[0]
96
- all_cells_percentage["guide_number"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[1]
97
- all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
98
- NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
99
-
100
- p1 = (
101
- ggplot(NP_KO_cells, aes(x="guide_number", y="value", fill="mixscape_class_global"))
102
- + scale_fill_manual(values=["#7d7d7d", "#c9c9c9", "#ff7256"])
103
- + geom_bar(stat="identity")
104
- + theme_classic()
105
- + xlab("sgRNA")
106
- + ylab("% of cells")
54
+ warnings.warn(
55
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
56
+ " Please use the corresponding 'pt.tl' object",
57
+ FutureWarning,
58
+ stacklevel=2,
107
59
  )
108
60
 
109
- p1 = (
110
- p1
111
- + theme(
112
- axis_text_x=element_text(size=axis_text_x_size, hjust=2),
113
- axis_text_y=element_text(size=axis_text_y_size),
114
- axis_title=element_text(size=axis_title_size),
115
- strip_text=element_text(size=strip_text_size, face="bold"),
116
- panel_spacing_x=panel_spacing_x,
117
- panel_spacing_y=panel_spacing_y,
118
- )
119
- + facet_wrap("gene", ncol=5, scales="free")
120
- + labs(fill="mixscape class")
121
- + theme(legend_title=element_text(size=legend_title_size), legend_text=element_text(size=legend_text_size))
61
+ from pertpy.tools import Mixscape
62
+
63
+ ms = Mixscape()
64
+ return ms.plot_barplot(
65
+ adata=adata,
66
+ guide_rna_column=guide_rna_column,
67
+ mixscape_class_global=mixscape_class_global,
68
+ axis_text_x_size=axis_text_x_size,
69
+ axis_text_y_size=axis_text_y_size,
70
+ axis_title_size=axis_title_size,
71
+ legend_title_size=legend_title_size,
72
+ legend_text_size=legend_text_size,
73
+ show=show,
74
+ save=save,
122
75
  )
123
76
 
124
- _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
125
- if not show:
126
- return p1
127
-
128
77
  @staticmethod
129
78
  def heatmap( # pragma: no cover
130
79
  adata: AnnData,
@@ -138,7 +87,7 @@ class MixscapePlot:
138
87
  vmax: float | None = 2,
139
88
  show: bool | None = None,
140
89
  save: bool | str | None = None,
141
- **kwds,
90
+ **kwargs,
142
91
  ):
143
92
  """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
144
93
 
@@ -154,33 +103,41 @@ class MixscapePlot:
154
103
  vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
155
104
  show: Show the plot, do not return axis.
156
105
  save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
157
- ax: A matplotlib axes object. Only works if plotting a single component.
158
106
  **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
159
107
 
160
108
  Examples:
161
109
  >>> import pertpy as pt
162
110
  >>> mdata = pt.dt.papalexi_2021()
163
- >>> mixscape_identifier = pt.tl.Mixscape()
164
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
165
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
166
- >>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
111
+ >>> ms = pt.tl.Mixscape()
112
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
113
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
114
+ >>> ms.plot_heatmap(
115
+ ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
116
+ ... )
167
117
  """
168
- if "mixscape_class" not in adata.obs:
169
- raise ValueError("Please run `pt.tl.mixscape` first.")
170
- adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
171
- sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
172
- sc.pp.scale(adata_subset, max_value=vmax)
173
- sc.pp.subsample(adata_subset, n_obs=subsample_number)
174
- return sc.pl.rank_genes_groups_heatmap(
175
- adata_subset,
176
- groupby="mixscape_class",
118
+ warnings.warn(
119
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
120
+ " Please use the corresponding 'pt.tl' object",
121
+ FutureWarning,
122
+ stacklevel=2,
123
+ )
124
+
125
+ from pertpy.tools import Mixscape
126
+
127
+ ms = Mixscape()
128
+ return ms.plot_heatmap(
129
+ adata=adata,
130
+ labels=labels,
131
+ target_gene=target_gene,
132
+ control=control,
133
+ layer=layer,
134
+ method=method,
135
+ subsample_number=subsample_number,
177
136
  vmin=vmin,
178
137
  vmax=vmax,
179
- n_genes=20,
180
- groups=["NT"],
181
138
  show=show,
182
139
  save=save,
183
- **kwds,
140
+ **kwargs,
184
141
  )
185
142
 
186
143
  @staticmethod
@@ -218,132 +175,32 @@ class MixscapePlot:
218
175
  >>> import pertpy as pt
219
176
  >>> mdata = pt.dt.papalexi_2021()
220
177
  >>> mixscape_identifier = pt.tl.Mixscape()
221
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
222
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
223
- >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
178
+ >>> mixscape_identifier.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
179
+ >>> mixscape_identifier.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
180
+ >>> mixscape_identifier.perturbscore(
181
+ ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange"
182
+ ... )
224
183
  """
225
- if "mixscape" not in adata.uns:
226
- raise ValueError("Please run `pt.tl.mixscape` first.")
227
- perturbation_score = None
228
- for key in adata.uns["mixscape"][target_gene].keys():
229
- perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
230
- perturbation_score_temp["name"] = key
231
- if perturbation_score is None:
232
- perturbation_score = copy.deepcopy(perturbation_score_temp)
233
- else:
234
- perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
235
- perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
236
- gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
237
- # If before_mixscape is True, split densities based on original target gene classification
238
- if before_mixscape is True:
239
- cols = {gd: "#7d7d7d", target_gene: color}
240
- p = ggplot(perturbation_score, aes(x="pvec", color=labels)) + geom_density() + theme_classic()
241
- p_copy = copy.deepcopy(p)
242
- p_copy._build()
243
- top_r = max(p_copy.layers[0].data["density"])
244
- perturbation_score["y_jitter"] = perturbation_score["pvec"]
245
- rng = np.random.default_rng()
246
- perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
247
- low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
248
- )
249
- perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
250
- low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
251
- )
252
- # If split_by is provided, split densities based on the split_by
253
- if split_by is not None:
254
- perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
255
- p2 = (
256
- p
257
- + scale_color_manual(values=cols, drop=False)
258
- + geom_density(size=1.5)
259
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
260
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
261
- + ylab("Cell density")
262
- + xlab("Perturbation score")
263
- + theme(
264
- legend_key_size=1,
265
- legend_text=element_text(colour="black", size=14),
266
- legend_title=element_blank(),
267
- plot_title=element_text(size=16, face="bold"),
268
- )
269
- + facet_wrap("split")
270
- )
271
- else:
272
- p2 = (
273
- p
274
- + scale_color_manual(values=cols, drop=False)
275
- + geom_density(size=1.5)
276
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
277
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
278
- + ylab("Cell density")
279
- + xlab("Perturbation score")
280
- + theme(
281
- legend_key_size=1,
282
- legend_text=element_text(colour="black", size=14),
283
- legend_title=element_blank(),
284
- plot_title=element_text(size=16, face="bold"),
285
- )
286
- )
287
- # If before_mixscape is False, split densities based on mixscape classifications
288
- else:
289
- cols = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
290
- p = ggplot(perturbation_score, aes(x="pvec", color="mix")) + geom_density() + theme_classic()
291
- p_copy = copy.deepcopy(p)
292
- p_copy._build()
293
- top_r = max(p_copy.layers[0].data["density"])
294
- perturbation_score["y_jitter"] = perturbation_score["pvec"]
295
- rng = np.random.default_rng()
296
- gd2 = list(
297
- set(perturbation_score["mix"]).difference([f"{target_gene} NP", f"{target_gene} {perturbation_type}"])
298
- )[0]
299
- perturbation_score.loc[perturbation_score["mix"] == gd2, "y_jitter"] = rng.uniform(
300
- low=0.001, high=top_r / 10, size=sum(perturbation_score["mix"] == gd2)
301
- )
302
- perturbation_score.loc[
303
- perturbation_score["mix"] == f"{target_gene} {perturbation_type}", "y_jitter"
304
- ] = rng.uniform(
305
- low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} {perturbation_type}")
306
- )
307
- perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} NP", "y_jitter"] = rng.uniform(
308
- low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} NP")
309
- )
310
- # If split_by is provided, split densities based on the split_by
311
- if split_by is not None:
312
- perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
313
- p2 = (
314
- ggplot(perturbation_score, aes(x="pvec", color="mix"))
315
- + scale_color_manual(values=cols, drop=False)
316
- + geom_density(size=1.5)
317
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
318
- + theme_classic()
319
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
320
- + ylab("Cell density")
321
- + xlab("Perturbation score")
322
- + theme(
323
- legend_key_size=1,
324
- legend_text=element_text(colour="black", size=14),
325
- legend_title=element_blank(),
326
- plot_title=element_text(size=16, face="bold"),
327
- )
328
- + facet_wrap("split")
329
- )
330
- else:
331
- p2 = (
332
- p
333
- + scale_color_manual(values=cols, drop=False)
334
- + geom_density(size=1.5)
335
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
336
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
337
- + ylab("Cell density")
338
- + xlab("Perturbation score")
339
- + theme(
340
- legend_key_size=1,
341
- legend_text=element_text(colour="black", size=14),
342
- legend_title=element_blank(),
343
- plot_title=element_text(size=16, face="bold"),
344
- )
345
- )
346
- return p2
184
+ warnings.warn(
185
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
186
+ " Please use the corresponding 'pt.tl' object",
187
+ FutureWarning,
188
+ stacklevel=2,
189
+ )
190
+
191
+ from pertpy.tools import Mixscape
192
+
193
+ ms = Mixscape()
194
+ return ms.plot_perturbscore(
195
+ adata=adata,
196
+ labels=labels,
197
+ target_gene=target_gene,
198
+ mixscape_class=mixscape_class,
199
+ color=color,
200
+ split_by=split_by,
201
+ before_mixscape=before_mixscape,
202
+ perturbation_type=perturbation_type,
203
+ )
347
204
 
348
205
  @staticmethod
349
206
  def violin( # pragma: no cover
@@ -367,7 +224,7 @@ class MixscapePlot:
367
224
  show: bool | None = None,
368
225
  save: bool | str | None = None,
369
226
  ax: Axes | None = None,
370
- **kwds,
227
+ **kwargs,
371
228
  ):
372
229
  """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.
373
230
 
@@ -385,7 +242,7 @@ class MixscapePlot:
385
242
  show: Show the plot, do not return axis.
386
243
  save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
387
244
  ax: A matplotlib axes object. Only works if plotting a single component.
388
- **kwds: Additional arguments to `seaborn.violinplot`.
245
+ **kwargs: Additional arguments to `seaborn.violinplot`.
389
246
 
390
247
  Returns:
391
248
  A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
@@ -393,153 +250,46 @@ class MixscapePlot:
393
250
  Examples:
394
251
  >>> import pertpy as pt
395
252
  >>> mdata = pt.dt.papalexi_2021()
396
- >>> mixscape_identifier = pt.tl.Mixscape()
397
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
398
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
399
- >>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
253
+ >>> ms = pt.tl.Mixscape()
254
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
255
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
256
+ >>> ms.plot_violin(
257
+ ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
258
+ ... )
400
259
  """
401
- if isinstance(target_gene_idents, str):
402
- mixscape_class_mask = adata.obs[groupby] == target_gene_idents
403
- elif isinstance(target_gene_idents, list):
404
- mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
405
- for ident in target_gene_idents:
406
- mixscape_class_mask |= adata.obs[groupby] == ident
407
- adata = adata[mixscape_class_mask]
408
-
409
- import seaborn as sns # Slow import, only import if called
410
-
411
- sanitize_anndata(adata)
412
- use_raw = _check_use_raw(adata, use_raw)
413
- if isinstance(keys, str):
414
- keys = [keys]
415
- keys = list(OrderedDict.fromkeys(keys)) # remove duplicates, preserving the order
416
-
417
- if isinstance(ylabel, (str, type(None))):
418
- ylabel = [ylabel] * (1 if groupby is None else len(keys))
419
- if groupby is None:
420
- if len(ylabel) != 1:
421
- raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
422
- elif len(ylabel) != len(keys):
423
- raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
424
-
425
- if groupby is not None:
426
- if hue is not None:
427
- obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
428
- else:
429
- obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
430
-
431
- else:
432
- obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
433
- if groupby is None:
434
- obs_tidy = pd.melt(obs_df, value_vars=keys)
435
- x = "variable"
436
- ys = ["value"]
437
- else:
438
- obs_tidy = obs_df
439
- x = groupby
440
- ys = keys
441
-
442
- if multi_panel and groupby is None and len(ys) == 1:
443
- # This is a quick and dirty way for adapting scales across several
444
- # keys if groupby is None.
445
- y = ys[0]
446
-
447
- g = sns.catplot(
448
- y=y,
449
- data=obs_tidy,
450
- kind="violin",
451
- scale=scale,
452
- col=x,
453
- col_order=keys,
454
- sharey=False,
455
- order=keys,
456
- cut=0,
457
- inner=None,
458
- **kwds,
459
- )
460
-
461
- if stripplot:
462
- grouped_df = obs_tidy.groupby(x)
463
- for ax_id, key in zip(range(g.axes.shape[1]), keys):
464
- sns.stripplot(
465
- y=y,
466
- data=grouped_df.get_group(key),
467
- jitter=jitter,
468
- size=size,
469
- color="black",
470
- ax=g.axes[0, ax_id],
471
- )
472
- if log:
473
- g.set(yscale="log")
474
- g.set_titles(col_template="{col_name}").set_xlabels("")
475
- if rotation is not None:
476
- for ax in g.axes[0]:
477
- ax.tick_params(axis="x", labelrotation=rotation)
478
- else:
479
- # set by default the violin plot cut=0 to limit the extend
480
- # of the violin plot (see stacked_violin code) for more info.
481
- kwds.setdefault("cut", 0)
482
- kwds.setdefault("inner")
483
-
484
- if ax is None:
485
- axs, _, _, _ = _utils.setup_axes(
486
- ax=ax,
487
- panels=["x"] if groupby is None else keys,
488
- show_ticks=True,
489
- right_margin=0.3,
490
- )
491
- else:
492
- axs = [ax]
493
- for ax, y, ylab in zip(axs, ys, ylabel): # noqa: F402
494
- ax = sns.violinplot(
495
- x=x,
496
- y=y,
497
- data=obs_tidy,
498
- order=order,
499
- orient="vertical",
500
- scale=scale,
501
- ax=ax,
502
- hue=hue,
503
- **kwds,
504
- )
505
- # Get the handles and labels.
506
- handles, labels = ax.get_legend_handles_labels()
507
- if stripplot:
508
- ax = sns.stripplot(
509
- x=x,
510
- y=y,
511
- data=obs_tidy,
512
- order=order,
513
- jitter=jitter,
514
- color="black",
515
- size=size,
516
- ax=ax,
517
- hue=hue,
518
- dodge=True,
519
- )
520
- if xlabel == "" and groupby is not None and rotation is None:
521
- xlabel = groupby.replace("_", " ")
522
- ax.set_xlabel(xlabel)
523
- if ylab is not None:
524
- ax.set_ylabel(ylab)
525
-
526
- if log:
527
- ax.set_yscale("log")
528
- if rotation is not None:
529
- ax.tick_params(axis="x", labelrotation=rotation)
530
-
531
- show = settings.autoshow if show is None else show
532
- if hue is not None and stripplot is True:
533
- pl.legend(handles, labels)
534
- _utils.savefig_or_show("mixscape_violin", show=show, save=save)
535
-
536
- if not show:
537
- if multi_panel and groupby is None and len(ys) == 1:
538
- return g
539
- elif len(axs) == 1:
540
- return axs[0]
541
- else:
542
- return axs
260
+ warnings.warn(
261
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
262
+ " Please use the corresponding 'pt.tl' object",
263
+ FutureWarning,
264
+ stacklevel=2,
265
+ )
266
+
267
+ from pertpy.tools import Mixscape
268
+
269
+ ms = Mixscape()
270
+ return ms.plot_violin(
271
+ adata=adata,
272
+ target_gene_idents=target_gene_idents,
273
+ keys=keys,
274
+ groupby=groupby,
275
+ log=log,
276
+ use_raw=use_raw,
277
+ stripplot=stripplot,
278
+ hue=hue,
279
+ jitter=jitter,
280
+ size=size,
281
+ layer=layer,
282
+ scale=scale,
283
+ order=order,
284
+ multi_panel=multi_panel,
285
+ xlabel=xlabel,
286
+ ylabel=ylabel,
287
+ rotation=rotation,
288
+ show=show,
289
+ save=save,
290
+ ax=ax,
291
+ **kwargs,
292
+ )
543
293
 
544
294
  @staticmethod
545
295
  def lda( # pragma: no cover
@@ -552,7 +302,7 @@ class MixscapePlot:
552
302
  n_components: int | None = None,
553
303
  show: bool | None = None,
554
304
  save: bool | str | None = None,
555
- **kwds,
305
+ **kwargs,
556
306
  ):
557
307
  """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
558
308
 
@@ -562,33 +312,44 @@ class MixscapePlot:
562
312
  labels: The column of `.obs` with target gene labels.
563
313
  mixscape_class: The column of `.obs` with the mixscape classification result.
564
314
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
565
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to 'KO'.
315
+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
316
+ Defaults to 'KO'.
566
317
  lda_key: If not speficied, lda looks .uns["mixscape_lda"] for the LDA results.
567
318
  n_components: The number of dimensions of the embedding.
568
319
  show: Show the plot, do not return axis.
569
- save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
570
- **kwds: Additional arguments to `scanpy.pl.umap`.
320
+ save: If `True` or a `str`, save the figure. A string is appended to the default filename.
321
+ Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
322
+ **kwargs: Additional arguments to `scanpy.pl.umap`.
571
323
 
572
324
  Examples:
573
325
  >>> import pertpy as pt
574
326
  >>> mdata = pt.dt.papalexi_2021()
575
- >>> mixscape_identifier = pt.tl.Mixscape()
576
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
577
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
578
- >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
579
- >>> pt.pl.ms.lda(adata=mdata['rna'], control='NT')
327
+ >>> ms = pt.tl.Mixscape()
328
+ >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
329
+ >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
330
+ >>> ms.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
331
+ >>> ms.plot_lda(adata=mdata["rna"], control="NT")
580
332
  """
581
- if mixscape_class not in adata.obs:
582
- raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
583
- if lda_key not in adata.uns:
584
- raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.')
585
-
586
- adata_subset = adata[
587
- (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
588
- ].copy()
589
- adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
590
- if n_components is None:
591
- n_components = adata_subset.uns[lda_key].shape[1]
592
- sc.pp.neighbors(adata_subset, use_rep=lda_key)
593
- sc.tl.umap(adata_subset, n_components=n_components)
594
- sc.pl.umap(adata_subset, color=mixscape_class, show=show, save=save, **kwds)
333
+ warnings.warn(
334
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
335
+ " Please use the corresponding 'pt.tl' object",
336
+ FutureWarning,
337
+ stacklevel=2,
338
+ )
339
+
340
+ from pertpy.tools import Mixscape
341
+
342
+ ms = Mixscape()
343
+
344
+ return ms.plot_lda(
345
+ adata=adata,
346
+ control=control,
347
+ mixscape_class=mixscape_class,
348
+ mixscape_class_global=mixscape_class_global,
349
+ perturbation_type=perturbation_type,
350
+ lda_key=lda_key,
351
+ n_components=n_components,
352
+ show=show,
353
+ save=save,
354
+ **kwargs,
355
+ )