pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_mixscape.py DELETED
@@ -1,594 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- from collections import OrderedDict
5
- from typing import TYPE_CHECKING, Literal
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import scanpy as sc
10
- from matplotlib import pyplot as pl
11
- from plotnine import (
12
- aes,
13
- element_blank,
14
- element_text,
15
- facet_wrap,
16
- geom_bar,
17
- geom_density,
18
- geom_point,
19
- ggplot,
20
- labs,
21
- scale_color_manual,
22
- scale_fill_manual,
23
- theme,
24
- theme_classic,
25
- xlab,
26
- ylab,
27
- )
28
- from scanpy import get
29
- from scanpy._settings import settings
30
- from scanpy._utils import _check_use_raw, sanitize_anndata
31
- from scanpy.plotting import _utils
32
-
33
- if TYPE_CHECKING:
34
- from collections.abc import Sequence
35
-
36
- from anndata import AnnData
37
- from matplotlib.axes import Axes
38
-
39
-
40
- class MixscapePlot:
41
- """Plotting functions for Mixscape."""
42
-
43
- @staticmethod
44
- def barplot( # pragma: no cover
45
- adata: AnnData,
46
- guide_rna_column: str,
47
- mixscape_class_global="mixscape_class_global",
48
- axis_text_x_size: int = 8,
49
- axis_text_y_size: int = 6,
50
- axis_title_size: int = 8,
51
- strip_text_size: int = 6,
52
- panel_spacing_x: float = 0.3,
53
- panel_spacing_y: float = 0.3,
54
- legend_title_size: int = 8,
55
- legend_text_size: int = 8,
56
- show: bool | None = None,
57
- save: bool | str | None = None,
58
- ):
59
- """Barplot to visualize perturbation scores calculated from RunMixscape function.
60
-
61
- Args:
62
- adata: The annotated data object.
63
- guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
64
- The format must be <gene_target>g<#>. For example, 'STAT2g1' and 'ATF2g1'.
65
- mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
66
- show: Show the plot, do not return axis.
67
- save: If True or a str, save the figure. A string is appended to the default filename.
68
- Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
69
-
70
- Returns:
71
- If show is False, return ggplot object used to draw the plot.
72
-
73
- Examples:
74
- >>> import pertpy as pt
75
- >>> mdata = pt.dt.papalexi_2021()
76
- >>> mixscape_identifier = pt.tl.Mixscape()
77
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
78
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
79
- >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
80
- """
81
- if mixscape_class_global not in adata.obs:
82
- raise ValueError("Please run `pt.tl.mixscape` first.")
83
- count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
84
- all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
85
- KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
86
- KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)
87
-
88
- new_levels = KO_cells_percentage[guide_rna_column]
89
- all_cells_percentage[guide_rna_column] = pd.Categorical(
90
- all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
91
- )
92
- all_cells_percentage[mixscape_class_global] = pd.Categorical(
93
- all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
94
- )
95
- all_cells_percentage["gene"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[0]
96
- all_cells_percentage["guide_number"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[1]
97
- all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
98
- NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
99
-
100
- p1 = (
101
- ggplot(NP_KO_cells, aes(x="guide_number", y="value", fill="mixscape_class_global"))
102
- + scale_fill_manual(values=["#7d7d7d", "#c9c9c9", "#ff7256"])
103
- + geom_bar(stat="identity")
104
- + theme_classic()
105
- + xlab("sgRNA")
106
- + ylab("% of cells")
107
- )
108
-
109
- p1 = (
110
- p1
111
- + theme(
112
- axis_text_x=element_text(size=axis_text_x_size, hjust=2),
113
- axis_text_y=element_text(size=axis_text_y_size),
114
- axis_title=element_text(size=axis_title_size),
115
- strip_text=element_text(size=strip_text_size, face="bold"),
116
- panel_spacing_x=panel_spacing_x,
117
- panel_spacing_y=panel_spacing_y,
118
- )
119
- + facet_wrap("gene", ncol=5, scales="free")
120
- + labs(fill="mixscape class")
121
- + theme(legend_title=element_text(size=legend_title_size), legend_text=element_text(size=legend_text_size))
122
- )
123
-
124
- _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
125
- if not show:
126
- return p1
127
-
128
- @staticmethod
129
- def heatmap( # pragma: no cover
130
- adata: AnnData,
131
- labels: str,
132
- target_gene: str,
133
- control: str,
134
- layer: str | None = None,
135
- method: str | None = "wilcoxon",
136
- subsample_number: int | None = 900,
137
- vmin: float | None = -2,
138
- vmax: float | None = 2,
139
- show: bool | None = None,
140
- save: bool | str | None = None,
141
- **kwds,
142
- ):
143
- """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
144
-
145
- Args:
146
- adata: The annotated data object.
147
- labels: The column of `.obs` with target gene labels.
148
- target_gene: Target gene name to visualize heatmap for.
149
- control: Control category from the `pert_key` column.
150
- layer: Key from `adata.layers` whose value will be used to perform tests on.
151
- method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
152
- subsample_number: Subsample to this number of observations.
153
- vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
154
- vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
155
- show: Show the plot, do not return axis.
156
- save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
157
- ax: A matplotlib axes object. Only works if plotting a single component.
158
- **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
159
-
160
- Examples:
161
- >>> import pertpy as pt
162
- >>> mdata = pt.dt.papalexi_2021()
163
- >>> mixscape_identifier = pt.tl.Mixscape()
164
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
165
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
166
- >>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
167
- """
168
- if "mixscape_class" not in adata.obs:
169
- raise ValueError("Please run `pt.tl.mixscape` first.")
170
- adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
171
- sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
172
- sc.pp.scale(adata_subset, max_value=vmax)
173
- sc.pp.subsample(adata_subset, n_obs=subsample_number)
174
- return sc.pl.rank_genes_groups_heatmap(
175
- adata_subset,
176
- groupby="mixscape_class",
177
- vmin=vmin,
178
- vmax=vmax,
179
- n_genes=20,
180
- groups=["NT"],
181
- show=show,
182
- save=save,
183
- **kwds,
184
- )
185
-
186
- @staticmethod
187
- def perturbscore( # pragma: no cover
188
- adata: AnnData,
189
- labels: str,
190
- target_gene: str,
191
- mixscape_class="mixscape_class",
192
- color="orange",
193
- split_by: str = None,
194
- before_mixscape=False,
195
- perturbation_type: str = "KO",
196
- ):
197
- """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first.
198
-
199
- https://satijalab.org/seurat/reference/plotperturbscore
200
-
201
- Args:
202
- adata: The annotated data object.
203
- labels: The column of `.obs` with target gene labels.
204
- target_gene: Target gene name to visualize perturbation scores for.
205
- mixscape_class: The column of `.obs` with mixscape classifications.
206
- color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
207
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
208
- the perturbation signature for every replicate separately.
209
- before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID.
210
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO.
211
-
212
- Returns:
213
- The ggplot object used for drawn.
214
-
215
- Examples:
216
- Visualizing the perturbation scores for the cells in a dataset:
217
-
218
- >>> import pertpy as pt
219
- >>> mdata = pt.dt.papalexi_2021()
220
- >>> mixscape_identifier = pt.tl.Mixscape()
221
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
222
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
223
- >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
224
- """
225
- if "mixscape" not in adata.uns:
226
- raise ValueError("Please run `pt.tl.mixscape` first.")
227
- perturbation_score = None
228
- for key in adata.uns["mixscape"][target_gene].keys():
229
- perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
230
- perturbation_score_temp["name"] = key
231
- if perturbation_score is None:
232
- perturbation_score = copy.deepcopy(perturbation_score_temp)
233
- else:
234
- perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
235
- perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
236
- gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
237
- # If before_mixscape is True, split densities based on original target gene classification
238
- if before_mixscape is True:
239
- cols = {gd: "#7d7d7d", target_gene: color}
240
- p = ggplot(perturbation_score, aes(x="pvec", color=labels)) + geom_density() + theme_classic()
241
- p_copy = copy.deepcopy(p)
242
- p_copy._build()
243
- top_r = max(p_copy.layers[0].data["density"])
244
- perturbation_score["y_jitter"] = perturbation_score["pvec"]
245
- rng = np.random.default_rng()
246
- perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
247
- low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
248
- )
249
- perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
250
- low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
251
- )
252
- # If split_by is provided, split densities based on the split_by
253
- if split_by is not None:
254
- perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
255
- p2 = (
256
- p
257
- + scale_color_manual(values=cols, drop=False)
258
- + geom_density(size=1.5)
259
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
260
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
261
- + ylab("Cell density")
262
- + xlab("Perturbation score")
263
- + theme(
264
- legend_key_size=1,
265
- legend_text=element_text(colour="black", size=14),
266
- legend_title=element_blank(),
267
- plot_title=element_text(size=16, face="bold"),
268
- )
269
- + facet_wrap("split")
270
- )
271
- else:
272
- p2 = (
273
- p
274
- + scale_color_manual(values=cols, drop=False)
275
- + geom_density(size=1.5)
276
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
277
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
278
- + ylab("Cell density")
279
- + xlab("Perturbation score")
280
- + theme(
281
- legend_key_size=1,
282
- legend_text=element_text(colour="black", size=14),
283
- legend_title=element_blank(),
284
- plot_title=element_text(size=16, face="bold"),
285
- )
286
- )
287
- # If before_mixscape is False, split densities based on mixscape classifications
288
- else:
289
- cols = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
290
- p = ggplot(perturbation_score, aes(x="pvec", color="mix")) + geom_density() + theme_classic()
291
- p_copy = copy.deepcopy(p)
292
- p_copy._build()
293
- top_r = max(p_copy.layers[0].data["density"])
294
- perturbation_score["y_jitter"] = perturbation_score["pvec"]
295
- rng = np.random.default_rng()
296
- gd2 = list(
297
- set(perturbation_score["mix"]).difference([f"{target_gene} NP", f"{target_gene} {perturbation_type}"])
298
- )[0]
299
- perturbation_score.loc[perturbation_score["mix"] == gd2, "y_jitter"] = rng.uniform(
300
- low=0.001, high=top_r / 10, size=sum(perturbation_score["mix"] == gd2)
301
- )
302
- perturbation_score.loc[
303
- perturbation_score["mix"] == f"{target_gene} {perturbation_type}", "y_jitter"
304
- ] = rng.uniform(
305
- low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} {perturbation_type}")
306
- )
307
- perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} NP", "y_jitter"] = rng.uniform(
308
- low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} NP")
309
- )
310
- # If split_by is provided, split densities based on the split_by
311
- if split_by is not None:
312
- perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
313
- p2 = (
314
- ggplot(perturbation_score, aes(x="pvec", color="mix"))
315
- + scale_color_manual(values=cols, drop=False)
316
- + geom_density(size=1.5)
317
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
318
- + theme_classic()
319
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
320
- + ylab("Cell density")
321
- + xlab("Perturbation score")
322
- + theme(
323
- legend_key_size=1,
324
- legend_text=element_text(colour="black", size=14),
325
- legend_title=element_blank(),
326
- plot_title=element_text(size=16, face="bold"),
327
- )
328
- + facet_wrap("split")
329
- )
330
- else:
331
- p2 = (
332
- p
333
- + scale_color_manual(values=cols, drop=False)
334
- + geom_density(size=1.5)
335
- + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
336
- + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
337
- + ylab("Cell density")
338
- + xlab("Perturbation score")
339
- + theme(
340
- legend_key_size=1,
341
- legend_text=element_text(colour="black", size=14),
342
- legend_title=element_blank(),
343
- plot_title=element_text(size=16, face="bold"),
344
- )
345
- )
346
- return p2
347
-
348
- @staticmethod
349
- def violin( # pragma: no cover
350
- adata: AnnData,
351
- target_gene_idents: str | list[str],
352
- keys: str | Sequence[str] = "mixscape_class_p_ko",
353
- groupby: str | None = "mixscape_class",
354
- log: bool = False,
355
- use_raw: bool | None = None,
356
- stripplot: bool = True,
357
- hue: str | None = None,
358
- jitter: float | bool = True,
359
- size: int = 1,
360
- layer: str | None = None,
361
- scale: Literal["area", "count", "width"] = "width",
362
- order: Sequence[str] | None = None,
363
- multi_panel: bool | None = None,
364
- xlabel: str = "",
365
- ylabel: str | Sequence[str] | None = None,
366
- rotation: float | None = None,
367
- show: bool | None = None,
368
- save: bool | str | None = None,
369
- ax: Axes | None = None,
370
- **kwds,
371
- ):
372
- """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.
373
-
374
- Args:
375
- adata: The annotated data object.
376
- target_gene: Target gene name to plot.
377
- keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
378
- groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
379
- log: Plot on logarithmic axis.
380
- use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
381
- stripplot: Add a stripplot on top of the violin plot.
382
- order: Order in which to show the categories.
383
- xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
384
- ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`. If `None` and `groubpy` is not `None`, defaults to `keys`.
385
- show: Show the plot, do not return axis.
386
- save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
387
- ax: A matplotlib axes object. Only works if plotting a single component.
388
- **kwds: Additional arguments to `seaborn.violinplot`.
389
-
390
- Returns:
391
- A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
392
-
393
- Examples:
394
- >>> import pertpy as pt
395
- >>> mdata = pt.dt.papalexi_2021()
396
- >>> mixscape_identifier = pt.tl.Mixscape()
397
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
398
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
399
- >>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
400
- """
401
- if isinstance(target_gene_idents, str):
402
- mixscape_class_mask = adata.obs[groupby] == target_gene_idents
403
- elif isinstance(target_gene_idents, list):
404
- mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
405
- for ident in target_gene_idents:
406
- mixscape_class_mask |= adata.obs[groupby] == ident
407
- adata = adata[mixscape_class_mask]
408
-
409
- import seaborn as sns # Slow import, only import if called
410
-
411
- sanitize_anndata(adata)
412
- use_raw = _check_use_raw(adata, use_raw)
413
- if isinstance(keys, str):
414
- keys = [keys]
415
- keys = list(OrderedDict.fromkeys(keys)) # remove duplicates, preserving the order
416
-
417
- if isinstance(ylabel, (str, type(None))):
418
- ylabel = [ylabel] * (1 if groupby is None else len(keys))
419
- if groupby is None:
420
- if len(ylabel) != 1:
421
- raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
422
- elif len(ylabel) != len(keys):
423
- raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
424
-
425
- if groupby is not None:
426
- if hue is not None:
427
- obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
428
- else:
429
- obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
430
-
431
- else:
432
- obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
433
- if groupby is None:
434
- obs_tidy = pd.melt(obs_df, value_vars=keys)
435
- x = "variable"
436
- ys = ["value"]
437
- else:
438
- obs_tidy = obs_df
439
- x = groupby
440
- ys = keys
441
-
442
- if multi_panel and groupby is None and len(ys) == 1:
443
- # This is a quick and dirty way for adapting scales across several
444
- # keys if groupby is None.
445
- y = ys[0]
446
-
447
- g = sns.catplot(
448
- y=y,
449
- data=obs_tidy,
450
- kind="violin",
451
- scale=scale,
452
- col=x,
453
- col_order=keys,
454
- sharey=False,
455
- order=keys,
456
- cut=0,
457
- inner=None,
458
- **kwds,
459
- )
460
-
461
- if stripplot:
462
- grouped_df = obs_tidy.groupby(x)
463
- for ax_id, key in zip(range(g.axes.shape[1]), keys):
464
- sns.stripplot(
465
- y=y,
466
- data=grouped_df.get_group(key),
467
- jitter=jitter,
468
- size=size,
469
- color="black",
470
- ax=g.axes[0, ax_id],
471
- )
472
- if log:
473
- g.set(yscale="log")
474
- g.set_titles(col_template="{col_name}").set_xlabels("")
475
- if rotation is not None:
476
- for ax in g.axes[0]:
477
- ax.tick_params(axis="x", labelrotation=rotation)
478
- else:
479
- # set by default the violin plot cut=0 to limit the extend
480
- # of the violin plot (see stacked_violin code) for more info.
481
- kwds.setdefault("cut", 0)
482
- kwds.setdefault("inner")
483
-
484
- if ax is None:
485
- axs, _, _, _ = _utils.setup_axes(
486
- ax=ax,
487
- panels=["x"] if groupby is None else keys,
488
- show_ticks=True,
489
- right_margin=0.3,
490
- )
491
- else:
492
- axs = [ax]
493
- for ax, y, ylab in zip(axs, ys, ylabel): # noqa: F402
494
- ax = sns.violinplot(
495
- x=x,
496
- y=y,
497
- data=obs_tidy,
498
- order=order,
499
- orient="vertical",
500
- scale=scale,
501
- ax=ax,
502
- hue=hue,
503
- **kwds,
504
- )
505
- # Get the handles and labels.
506
- handles, labels = ax.get_legend_handles_labels()
507
- if stripplot:
508
- ax = sns.stripplot(
509
- x=x,
510
- y=y,
511
- data=obs_tidy,
512
- order=order,
513
- jitter=jitter,
514
- color="black",
515
- size=size,
516
- ax=ax,
517
- hue=hue,
518
- dodge=True,
519
- )
520
- if xlabel == "" and groupby is not None and rotation is None:
521
- xlabel = groupby.replace("_", " ")
522
- ax.set_xlabel(xlabel)
523
- if ylab is not None:
524
- ax.set_ylabel(ylab)
525
-
526
- if log:
527
- ax.set_yscale("log")
528
- if rotation is not None:
529
- ax.tick_params(axis="x", labelrotation=rotation)
530
-
531
- show = settings.autoshow if show is None else show
532
- if hue is not None and stripplot is True:
533
- pl.legend(handles, labels)
534
- _utils.savefig_or_show("mixscape_violin", show=show, save=save)
535
-
536
- if not show:
537
- if multi_panel and groupby is None and len(ys) == 1:
538
- return g
539
- elif len(axs) == 1:
540
- return axs[0]
541
- else:
542
- return axs
543
-
544
- @staticmethod
545
- def lda( # pragma: no cover
546
- adata: AnnData,
547
- control: str,
548
- mixscape_class="mixscape_class",
549
- mixscape_class_global="mixscape_class_global",
550
- perturbation_type: str | None = "KO",
551
- lda_key: str | None = "mixscape_lda",
552
- n_components: int | None = None,
553
- show: bool | None = None,
554
- save: bool | str | None = None,
555
- **kwds,
556
- ):
557
- """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
558
-
559
- Args:
560
- adata: The annotated data object.
561
- control: Control category from the `pert_key` column.
562
- labels: The column of `.obs` with target gene labels.
563
- mixscape_class: The column of `.obs` with the mixscape classification result.
564
- mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
565
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to 'KO'.
566
- lda_key: If not speficied, lda looks .uns["mixscape_lda"] for the LDA results.
567
- n_components: The number of dimensions of the embedding.
568
- show: Show the plot, do not return axis.
569
- save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
570
- **kwds: Additional arguments to `scanpy.pl.umap`.
571
-
572
- Examples:
573
- >>> import pertpy as pt
574
- >>> mdata = pt.dt.papalexi_2021()
575
- >>> mixscape_identifier = pt.tl.Mixscape()
576
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
577
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
578
- >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
579
- >>> pt.pl.ms.lda(adata=mdata['rna'], control='NT')
580
- """
581
- if mixscape_class not in adata.obs:
582
- raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
583
- if lda_key not in adata.uns:
584
- raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.')
585
-
586
- adata_subset = adata[
587
- (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
588
- ].copy()
589
- adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
590
- if n_components is None:
591
- n_components = adata_subset.uns[lda_key].shape[1]
592
- sc.pp.neighbors(adata_subset, use_rep=lda_key)
593
- sc.tl.umap(adata_subset, n_components=n_components)
594
- sc.pl.umap(adata_subset, color=mixscape_class, show=show, save=save, **kwds)