pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_mixscape.py DELETED
@@ -1,594 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- from collections import OrderedDict
5
- from typing import TYPE_CHECKING, Literal
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import scanpy as sc
10
- from matplotlib import pyplot as pl
11
- from plotnine import (
12
- aes,
13
- element_blank,
14
- element_text,
15
- facet_wrap,
16
- geom_bar,
17
- geom_density,
18
- geom_point,
19
- ggplot,
20
- labs,
21
- scale_color_manual,
22
- scale_fill_manual,
23
- theme,
24
- theme_classic,
25
- xlab,
26
- ylab,
27
- )
28
- from scanpy import get
29
- from scanpy._settings import settings
30
- from scanpy._utils import _check_use_raw, sanitize_anndata
31
- from scanpy.plotting import _utils
32
-
33
- if TYPE_CHECKING:
34
- from collections.abc import Sequence
35
-
36
- from anndata import AnnData
37
- from matplotlib.axes import Axes
38
-
39
-
40
class MixscapePlot:
    """Plotting functions for Mixscape."""

    @staticmethod
    def barplot(  # pragma: no cover
        adata: AnnData,
        guide_rna_column: str,
        mixscape_class_global="mixscape_class_global",
        axis_text_x_size: int = 8,
        axis_text_y_size: int = 6,
        axis_title_size: int = 8,
        strip_text_size: int = 6,
        panel_spacing_x: float = 0.3,
        panel_spacing_y: float = 0.3,
        legend_title_size: int = 8,
        legend_text_size: int = 8,
        show: bool | None = None,
        save: bool | str | None = None,
    ):
        """Barplot to visualize perturbation scores calculated from RunMixscape function.

        Args:
            adata: The annotated data object.
            guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
                              The format must be <gene_target>g<#>. For example, 'STAT2g1' and 'ATF2g1'.
            mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
            axis_text_x_size: Font size of the x axis tick labels.
            axis_text_y_size: Font size of the y axis tick labels.
            axis_title_size: Font size of the axis titles.
            strip_text_size: Font size of the facet (gene) titles.
            panel_spacing_x: Horizontal spacing between the facets.
            panel_spacing_y: Vertical spacing between the facets.
            legend_title_size: Font size of the legend title.
            legend_text_size: Font size of the legend entries.
            show: Show the plot, do not return axis.
            save: If True or a str, save the figure. A string is appended to the default filename.
                  Infer the filetype if ending on {'.pdf', '.png', '.svg'}.

        Returns:
            If show is False, return ggplot object used to draw the plot.

        Raises:
            ValueError: If `mixscape_class_global` is not a column of `adata.obs`.

        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> mixscape_identifier = pt.tl.Mixscape()
            >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
            >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
            >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
        """
        if mixscape_class_global not in adata.obs:
            raise ValueError("Please run `pt.tl.mixscape` first.")

        # Fraction of cells per mixscape class for every guide, in long format.
        count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
        all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()

        # Order the guides by their fraction of KO cells (highest first).
        KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
        KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)
        new_levels = KO_cells_percentage[guide_rna_column]

        all_cells_percentage[guide_rna_column] = pd.Categorical(
            all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
        )
        all_cells_percentage[mixscape_class_global] = pd.Categorical(
            all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
        )

        # Split '<gene_target>g<#>' at the LAST 'g' only, so that gene names which
        # themselves contain a lowercase 'g' are still parsed correctly.
        guide_parts = all_cells_percentage[guide_rna_column].str.rsplit("g", n=1, expand=True)
        all_cells_percentage["gene"] = guide_parts[0]
        all_cells_percentage["guide_number"] = "g" + guide_parts[1]
        NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]

        p1 = (
            # The fill column follows the `mixscape_class_global` argument instead of a fixed name.
            ggplot(NP_KO_cells, aes(x="guide_number", y="value", fill=mixscape_class_global))
            + scale_fill_manual(values=["#7d7d7d", "#c9c9c9", "#ff7256"])
            + geom_bar(stat="identity")
            + theme_classic()
            + xlab("sgRNA")
            + ylab("% of cells")
            + theme(
                axis_text_x=element_text(size=axis_text_x_size, hjust=2),
                axis_text_y=element_text(size=axis_text_y_size),
                axis_title=element_text(size=axis_title_size),
                strip_text=element_text(size=strip_text_size, face="bold"),
                panel_spacing_x=panel_spacing_x,
                panel_spacing_y=panel_spacing_y,
            )
            + facet_wrap("gene", ncol=5, scales="free")
            + labs(fill="mixscape class")
            + theme(legend_title=element_text(size=legend_title_size), legend_text=element_text(size=legend_text_size))
        )

        _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
        if not show:
            return p1

    @staticmethod
    def heatmap(  # pragma: no cover
        adata: AnnData,
        labels: str,
        target_gene: str,
        control: str,
        layer: str | None = None,
        method: str | None = "wilcoxon",
        subsample_number: int | None = 900,
        vmin: float | None = -2,
        vmax: float | None = 2,
        show: bool | None = None,
        save: bool | str | None = None,
        **kwds,
    ):
        """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.

        Args:
            adata: The annotated data object.
            labels: The column of `.obs` with target gene labels.
            target_gene: Target gene name to visualize heatmap for.
            control: Control category from the `labels` column. The genes ranked for this group are plotted.
            layer: Key from `adata.layers` whose value will be used to perform tests on.
            method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
            subsample_number: Subsample to this number of observations. Subsampling is skipped if `None`
                              or if the subset does not contain more observations than requested.
            vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
            vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
                  Also used as `max_value` when scaling the data.
            show: Show the plot, do not return axis.
            save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
            **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.

        Returns:
            Whatever `scanpy.pl.rank_genes_groups_heatmap` returns for the given `show` value.

        Raises:
            ValueError: If `mixscape_class` is not a column of `adata.obs`.

        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> mixscape_identifier = pt.tl.Mixscape()
            >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
            >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
            >>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
        """
        if "mixscape_class" not in adata.obs:
            raise ValueError("Please run `pt.tl.mixscape` first.")

        # Work on a copy restricted to the target gene and the control cells.
        adata_subset = adata[adata.obs[labels].isin([target_gene, control])].copy()
        sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
        sc.pp.scale(adata_subset, max_value=vmax)
        # Only subsample when a size is requested and it is smaller than the subset;
        # sampling without replacement cannot draw more cells than are available.
        if subsample_number is not None and subsample_number < adata_subset.n_obs:
            sc.pp.subsample(adata_subset, n_obs=subsample_number)

        return sc.pl.rank_genes_groups_heatmap(
            adata_subset,
            groupby="mixscape_class",
            vmin=vmin,
            vmax=vmax,
            n_genes=20,
            # Use the caller's control label rather than assuming it is named "NT".
            groups=[control],
            show=show,
            save=save,
            **kwds,
        )

    @staticmethod
    def perturbscore(  # pragma: no cover
        adata: AnnData,
        labels: str,
        target_gene: str,
        mixscape_class="mixscape_class",
        color="orange",
        split_by: str | None = None,
        before_mixscape=False,
        perturbation_type: str = "KO",
    ):
        """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first.

        https://satijalab.org/seurat/reference/plotperturbscore

        Args:
            adata: The annotated data object.
            labels: The column of `.obs` with target gene labels.
            target_gene: Target gene name to visualize perturbation scores for.
            mixscape_class: The column of `.obs` with mixscape classifications.
            color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
            split_by: Column of `.obs` used to facet the plot, e.g. the biological replicate.
                      If `None`, a single panel is drawn.
            before_mixscape: If `True`, split densities by the original target gene classification (`labels`).
                             Otherwise (default), split densities by the mixscape classification.
            perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO.

        Returns:
            The ggplot object of the density plot.

        Raises:
            ValueError: If `adata.uns` has no `mixscape` entry.

        Examples:
            Visualizing the perturbation scores for the cells in a dataset:

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> mixscape_identifier = pt.tl.Mixscape()
            >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
            >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
            >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
        """
        if "mixscape" not in adata.uns:
            raise ValueError("Please run `pt.tl.mixscape` first.")

        # `assign` returns new frames, so the results stored in `adata.uns` are left untouched.
        perturbation_score = pd.concat(
            [scores.assign(name=key) for key, scores in adata.uns["mixscape"][target_gene].items()]
        )
        perturbation_score["mix"] = adata.obs[mixscape_class].loc[perturbation_score.index]

        # The label in `labels` that is not the target gene, i.e. the control label.
        gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]

        if before_mixscape is True:
            # Split densities based on the original target gene classification.
            color_column = labels
            cols = {gd: "#7d7d7d", target_gene: color}
            upper_class = gd
            lower_classes = [target_gene]
        else:
            # Split densities based on the mixscape classifications.
            color_column = "mix"
            np_class = f"{target_gene} NP"
            perturbed_class = f"{target_gene} {perturbation_type}"
            cols = {gd: "#7d7d7d", np_class: "#c9c9c9", perturbed_class: color}
            upper_class = list(set(perturbation_score["mix"]).difference([np_class, perturbed_class]))[0]
            lower_classes = [perturbed_class, np_class]

        # Build a throw-away density plot only to read off the peak density,
        # which fixes the height of the jittered point bands.
        probe = ggplot(perturbation_score, aes(x="pvec", color=color_column)) + geom_density() + theme_classic()
        probe = copy.deepcopy(probe)
        probe._build()
        top_r = max(probe.layers[0].data["density"])

        # Points of `upper_class` are scattered just above zero, points of the
        # `lower_classes` just below; every other cell keeps its score as y value.
        perturbation_score["y_jitter"] = perturbation_score["pvec"]
        rng = np.random.default_rng()
        upper_mask = perturbation_score[color_column] == upper_class
        perturbation_score.loc[upper_mask, "y_jitter"] = rng.uniform(
            low=0.001, high=top_r / 10, size=int(upper_mask.sum())
        )
        for lower_class in lower_classes:
            lower_mask = perturbation_score[color_column] == lower_class
            perturbation_score.loc[lower_mask, "y_jitter"] = rng.uniform(
                low=-top_r / 10, high=0, size=int(lower_mask.sum())
            )

        if split_by is not None:
            perturbation_score["split"] = adata.obs[split_by].loc[perturbation_score.index]

        p2 = (
            ggplot(perturbation_score, aes(x="pvec", color=color_column))
            + scale_color_manual(values=cols, drop=False)
            + geom_density(size=1.5)
            + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
            + theme_classic()
            + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
            + ylab("Cell density")
            + xlab("Perturbation score")
            + theme(
                legend_key_size=1,
                legend_text=element_text(colour="black", size=14),
                legend_title=element_blank(),
                plot_title=element_text(size=16, face="bold"),
            )
        )
        if split_by is not None:
            p2 = p2 + facet_wrap("split")
        return p2

    @staticmethod
    def violin(  # pragma: no cover
        adata: AnnData,
        target_gene_idents: str | list[str],
        keys: str | Sequence[str] = "mixscape_class_p_ko",
        groupby: str | None = "mixscape_class",
        log: bool = False,
        use_raw: bool | None = None,
        stripplot: bool = True,
        hue: str | None = None,
        jitter: float | bool = True,
        size: int = 1,
        layer: str | None = None,
        scale: Literal["area", "count", "width"] = "width",
        order: Sequence[str] | None = None,
        multi_panel: bool | None = None,
        xlabel: str = "",
        ylabel: str | Sequence[str] | None = None,
        rotation: float | None = None,
        show: bool | None = None,
        save: bool | str | None = None,
        ax: Axes | None = None,
        **kwds,
    ):
        """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.

        Args:
            adata: The annotated data object.
            target_gene_idents: Category or categories of the `groupby` column to plot;
                                cells of all other categories are dropped.
            keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
            groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
            log: Plot on logarithmic axis.
            use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
            stripplot: Add a stripplot on top of the violin plot.
            hue: Passed as `hue` to `seaborn.violinplot` and `seaborn.stripplot`.
            jitter: Passed as `jitter` to `seaborn.stripplot`.
            size: Passed as `size` to `seaborn.stripplot`.
            layer: Passed as `layer` to `scanpy.get.obs_df`.
            scale: Passed as `scale` to the seaborn violin plot.
            order: Order in which to show the categories.
            multi_panel: If truthy, `groupby` is `None` and a single value column is plotted,
                         draw one panel per key via `seaborn.catplot`.
            xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
            ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`. If `None` and `groupby` is not `None`, defaults to `keys`.
            rotation: Rotation of the x axis tick labels.
            show: Show the plot, do not return axis.
            save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
            ax: A matplotlib axes object. Only works if plotting a single component.
            **kwds: Additional arguments to `seaborn.violinplot`.

        Returns:
            Nothing if the plot is shown. Otherwise the seaborn grid (multi-panel case),
            a single axes object, or the list of axes objects.

        Raises:
            ValueError: If the number of y-labels does not match the number of panels.

        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> mixscape_identifier = pt.tl.Mixscape()
            >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
            >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
            >>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
        """
        # Accept a single identifier as well as any iterable of identifiers
        # (not only `list`), so the mask is always defined.
        if isinstance(target_gene_idents, str):
            target_gene_idents = [target_gene_idents]
        mixscape_class_mask = adata.obs[groupby].isin(list(target_gene_idents))
        adata = adata[mixscape_class_mask]

        import seaborn as sns  # Slow import, only import if called

        sanitize_anndata(adata)
        use_raw = _check_use_raw(adata, use_raw)
        if isinstance(keys, str):
            keys = [keys]
        keys = list(dict.fromkeys(keys))  # remove duplicates, preserving the order

        if isinstance(ylabel, (str, type(None))):
            ylabel = [ylabel] * (1 if groupby is None else len(keys))
        if groupby is None:
            if len(ylabel) != 1:
                raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
        elif len(ylabel) != len(keys):
            raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")

        if groupby is None:
            obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
            obs_tidy = pd.melt(obs_df, value_vars=keys)
            x = "variable"
            ys = ["value"]
        else:
            obs_keys = [groupby] + keys if hue is None else [groupby] + keys + [hue]
            obs_tidy = get.obs_df(adata, keys=obs_keys, layer=layer, use_raw=use_raw)
            x = groupby
            ys = keys

        use_catplot = multi_panel and groupby is None and len(ys) == 1
        # Legend entries are only collected on the single-axes path below.
        handles = legend_labels = None

        if use_catplot:
            # This is a quick and dirty way for adapting scales across several
            # keys if groupby is None.
            y = ys[0]

            g = sns.catplot(
                y=y,
                data=obs_tidy,
                kind="violin",
                scale=scale,
                col=x,
                col_order=keys,
                sharey=False,
                order=keys,
                cut=0,
                inner=None,
                **kwds,
            )

            if stripplot:
                grouped_df = obs_tidy.groupby(x)
                for ax_id, key in zip(range(g.axes.shape[1]), keys):
                    sns.stripplot(
                        y=y,
                        data=grouped_df.get_group(key),
                        jitter=jitter,
                        size=size,
                        color="black",
                        ax=g.axes[0, ax_id],
                    )
            if log:
                g.set(yscale="log")
            g.set_titles(col_template="{col_name}").set_xlabels("")
            if rotation is not None:
                for panel_ax in g.axes[0]:
                    panel_ax.tick_params(axis="x", labelrotation=rotation)
        else:
            # set by default the violin plot cut=0 to limit the extend
            # of the violin plot (see stacked_violin code) for more info.
            kwds.setdefault("cut", 0)
            kwds.setdefault("inner")

            if ax is None:
                axs, _, _, _ = _utils.setup_axes(
                    ax=ax,
                    panels=["x"] if groupby is None else keys,
                    show_ticks=True,
                    right_margin=0.3,
                )
            else:
                axs = [ax]
            for panel_ax, y, y_label in zip(axs, ys, ylabel):
                panel_ax = sns.violinplot(
                    x=x,
                    y=y,
                    data=obs_tidy,
                    order=order,
                    orient="vertical",
                    scale=scale,
                    ax=panel_ax,
                    hue=hue,
                    **kwds,
                )
                # Get the handles and labels before the stripplot adds its own entries.
                handles, legend_labels = panel_ax.get_legend_handles_labels()
                if stripplot:
                    panel_ax = sns.stripplot(
                        x=x,
                        y=y,
                        data=obs_tidy,
                        order=order,
                        jitter=jitter,
                        color="black",
                        size=size,
                        ax=panel_ax,
                        hue=hue,
                        dodge=True,
                    )
                if xlabel == "" and groupby is not None and rotation is None:
                    xlabel = groupby.replace("_", " ")
                panel_ax.set_xlabel(xlabel)
                if y_label is not None:
                    panel_ax.set_ylabel(y_label)

                if log:
                    panel_ax.set_yscale("log")
                if rotation is not None:
                    panel_ax.tick_params(axis="x", labelrotation=rotation)

        show = settings.autoshow if show is None else show
        if hue is not None and stripplot is True and handles is not None:
            pl.legend(handles, legend_labels)
        _utils.savefig_or_show("mixscape_violin", show=show, save=save)

        if not show:
            if use_catplot:
                return g
            elif len(axs) == 1:
                return axs[0]
            else:
                return axs

    @staticmethod
    def lda(  # pragma: no cover
        adata: AnnData,
        control: str,
        mixscape_class="mixscape_class",
        mixscape_class_global="mixscape_class_global",
        perturbation_type: str | None = "KO",
        lda_key: str | None = "mixscape_lda",
        n_components: int | None = None,
        show: bool | None = None,
        save: bool | str | None = None,
        **kwds,
    ):
        """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.

        Args:
            adata: The annotated data object.
            control: Control category of the `mixscape_class_global` column.
            mixscape_class: The column of `.obs` with the mixscape classification result.
            mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
            perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to 'KO'.
            lda_key: Key of `.uns` holding the LDA results. Defaults to 'mixscape_lda'.
            n_components: The number of dimensions of the embedding.
                          Defaults to the number of columns of the LDA result.
            show: Show the plot, do not return axis.
            save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
            **kwds: Additional arguments to `scanpy.pl.umap`.

        Raises:
            ValueError: If `mixscape_class` is missing from `adata.obs` or `lda_key` is missing from `adata.uns`.

        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> mixscape_identifier = pt.tl.Mixscape()
            >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
            >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
            >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
            >>> pt.pl.ms.lda(adata=mdata['rna'], control='NT')
        """
        if mixscape_class not in adata.obs:
            raise ValueError(f"Did not find .obs[{mixscape_class!r}]. Please run `pt.tl.mixscape` first.")
        if lda_key not in adata.uns:
            raise ValueError(f"Did not find .uns[{lda_key!r}]. Please run the `lda` step of `pt.tl.Mixscape` first.")

        # Keep only the perturbed and the control cells.
        adata_subset = adata[adata.obs[mixscape_class_global].isin([perturbation_type, control])].copy()
        # Expose the LDA result as an embedding so it can serve as neighbors representation.
        adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
        if n_components is None:
            n_components = adata_subset.uns[lda_key].shape[1]
        sc.pp.neighbors(adata_subset, use_rep=lda_key)
        sc.tl.umap(adata_subset, n_components=n_components)
        sc.pl.umap(adata_subset, color=mixscape_class, show=show, save=save, **kwds)