pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -1,12 +1,18 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import
|
4
|
-
from
|
3
|
+
import copy
|
4
|
+
from collections import OrderedDict
|
5
|
+
from typing import TYPE_CHECKING, Literal
|
5
6
|
|
7
|
+
import matplotlib.pyplot as plt
|
6
8
|
import numpy as np
|
7
9
|
import pandas as pd
|
8
10
|
import scanpy as sc
|
9
|
-
|
11
|
+
import seaborn as sns
|
12
|
+
from scanpy import get
|
13
|
+
from scanpy._settings import settings
|
14
|
+
from scanpy._utils import _check_use_raw, sanitize_anndata
|
15
|
+
from scanpy.plotting import _utils
|
10
16
|
from scanpy.tools._utils import _choose_representation
|
11
17
|
from scipy.sparse import csr_matrix, issparse, spmatrix
|
12
18
|
from sklearn.mixture import GaussianMixture
|
@@ -14,11 +20,13 @@ from sklearn.mixture import GaussianMixture
|
|
14
20
|
import pertpy as pt
|
15
21
|
|
16
22
|
if TYPE_CHECKING:
|
23
|
+
from collections.abc import Sequence
|
24
|
+
|
17
25
|
from anndata import AnnData
|
26
|
+
from matplotlib.axes import Axes
|
27
|
+
from matplotlib.colors import Colormap
|
18
28
|
from scipy import sparse
|
19
29
|
|
20
|
-
warnings.simplefilter("ignore")
|
21
|
-
|
22
30
|
|
23
31
|
class Mixscape:
|
24
32
|
"""Python implementation of Mixscape."""
|
@@ -65,15 +73,15 @@ class Mixscape:
|
|
65
73
|
|
66
74
|
Returns:
|
67
75
|
If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
|
68
|
-
Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
|
76
|
+
Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
|
69
77
|
|
70
78
|
Examples:
|
71
79
|
Calculate perturbation signature for each cell in the dataset:
|
72
80
|
|
73
81
|
>>> import pertpy as pt
|
74
82
|
>>> mdata = pt.dt.papalexi_2021()
|
75
|
-
>>>
|
76
|
-
>>>
|
83
|
+
>>> ms_pt = pt.tl.Mixscape()
|
84
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
77
85
|
"""
|
78
86
|
if copy:
|
79
87
|
adata = adata.copy()
|
@@ -86,18 +94,17 @@ class Mixscape:
|
|
86
94
|
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
|
87
95
|
else:
|
88
96
|
split_obs = adata.obs[split_by]
|
89
|
-
|
90
|
-
split_masks = [split_obs == cat for cat in cats]
|
97
|
+
split_masks = [split_obs == cat for cat in split_obs.unique()]
|
91
98
|
|
92
|
-
|
99
|
+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
|
93
100
|
|
94
101
|
for split_mask in split_masks:
|
95
102
|
control_mask_split = control_mask & split_mask
|
96
103
|
|
97
|
-
R_split =
|
98
|
-
R_control =
|
104
|
+
R_split = representation[split_mask]
|
105
|
+
R_control = representation[control_mask_split]
|
99
106
|
|
100
|
-
from pynndescent import NNDescent
|
107
|
+
from pynndescent import NNDescent
|
101
108
|
|
102
109
|
eps = kwargs.pop("epsilon", 0.1)
|
103
110
|
nn_index = NNDescent(R_control, **kwargs)
|
@@ -161,7 +168,6 @@ class Mixscape:
|
|
161
168
|
|
162
169
|
Args:
|
163
170
|
adata: The annotated data object.
|
164
|
-
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
165
171
|
labels: The column of `.obs` with target gene labels.
|
166
172
|
control: Control category from the `pert_key` column.
|
167
173
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
@@ -172,31 +178,31 @@ class Mixscape:
|
|
172
178
|
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
|
173
179
|
the perturbation signature for every replicate separately.
|
174
180
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
175
|
-
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
181
|
+
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
176
182
|
copy: Determines whether a copy of the `adata` is returned.
|
177
183
|
|
178
184
|
Returns:
|
179
185
|
If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
|
180
|
-
Otherwise writes the results directly to `.obs` of the provided `adata`.
|
186
|
+
Otherwise, writes the results directly to `.obs` of the provided `adata`.
|
181
187
|
|
182
|
-
mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
|
183
|
-
|
188
|
+
- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
|
189
|
+
Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
|
184
190
|
|
185
|
-
mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
|
186
|
-
|
191
|
+
- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
|
192
|
+
Global classification result (perturbed, NP or NT).
|
187
193
|
|
188
|
-
mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
|
189
|
-
|
190
|
-
|
194
|
+
- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
|
195
|
+
Posterior probabilities used to determine if a cell is KO (default, >0.5) or NP.
|
196
|
+
Name of this item will change to match the perturbation_type parameter setting.
|
191
197
|
|
192
198
|
Examples:
|
193
199
|
Calculate perturbation signature for each cell in the dataset:
|
194
200
|
|
195
201
|
>>> import pertpy as pt
|
196
202
|
>>> mdata = pt.dt.papalexi_2021()
|
197
|
-
>>>
|
198
|
-
>>>
|
199
|
-
>>>
|
203
|
+
>>> ms_pt = pt.tl.Mixscape()
|
204
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
205
|
+
>>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
|
200
206
|
"""
|
201
207
|
if copy:
|
202
208
|
adata = adata.copy()
|
@@ -220,10 +226,9 @@ class Mixscape:
|
|
220
226
|
try:
|
221
227
|
X = adata_comp.layers["X_pert"]
|
222
228
|
except KeyError:
|
223
|
-
|
224
|
-
|
225
|
-
)
|
226
|
-
raise
|
229
|
+
raise KeyError(
|
230
|
+
"No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
|
231
|
+
) from None
|
227
232
|
# initialize return variables
|
228
233
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
229
234
|
adata.obs[new_class_name] = adata.obs[labels].astype(str)
|
@@ -305,12 +310,14 @@ class Mixscape:
|
|
305
310
|
old_classes = adata.obs[new_class_name][all_cells]
|
306
311
|
n_iter += 1
|
307
312
|
|
308
|
-
adata.obs.loc[
|
309
|
-
|
310
|
-
|
313
|
+
adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
|
314
|
+
f"{gene} {perturbation_type}"
|
315
|
+
)
|
311
316
|
|
312
317
|
adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
|
313
|
-
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =
|
318
|
+
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
|
319
|
+
post_prob
|
320
|
+
).astype("int64")
|
314
321
|
adata.uns["mixscape"] = gv_list
|
315
322
|
|
316
323
|
if copy:
|
@@ -339,18 +346,18 @@ class Mixscape:
|
|
339
346
|
control: Control category from the `pert_key` column.
|
340
347
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
341
348
|
layer: Key from `adata.layers` whose value will be used to perform tests on.
|
342
|
-
control: Control category from the `pert_key` column.
|
343
|
-
n_comps: Number of principal components to use.
|
349
|
+
control: Control category from the `pert_key` column.
|
350
|
+
n_comps: Number of principal components to use.
|
344
351
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
345
|
-
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
352
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
346
353
|
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
|
347
354
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
348
|
-
perturbation_type:
|
355
|
+
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
349
356
|
copy: Determines whether a copy of the `adata` is returned.
|
350
357
|
|
351
358
|
Returns:
|
352
359
|
If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
|
353
|
-
Otherwise writes the results directly to `.uns` of the provided `adata`.
|
360
|
+
Otherwise, writes the results directly to `.uns` of the provided `adata`.
|
354
361
|
|
355
362
|
mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
|
356
363
|
LDA result.
|
@@ -360,10 +367,10 @@ class Mixscape:
|
|
360
367
|
|
361
368
|
>>> import pertpy as pt
|
362
369
|
>>> mdata = pt.dt.papalexi_2021()
|
363
|
-
>>>
|
364
|
-
>>>
|
365
|
-
>>>
|
366
|
-
>>>
|
370
|
+
>>> ms_pt = pt.tl.Mixscape()
|
371
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
372
|
+
>>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
|
373
|
+
>>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
|
367
374
|
"""
|
368
375
|
if copy:
|
369
376
|
adata = adata.copy()
|
@@ -437,7 +444,7 @@ class Mixscape:
|
|
437
444
|
min_de_genes: float,
|
438
445
|
logfc_threshold: float,
|
439
446
|
) -> dict[tuple, np.ndarray]:
|
440
|
-
"""
|
447
|
+
"""Determine gene sets across all splits/groups through differential gene expression
|
441
448
|
|
442
449
|
Args:
|
443
450
|
adata: :class:`~anndata.AnnData` object
|
@@ -454,7 +461,13 @@ class Mixscape:
|
|
454
461
|
adata_split = adata[split_mask].copy()
|
455
462
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
456
463
|
sc.tl.rank_genes_groups(
|
457
|
-
adata_split,
|
464
|
+
adata_split,
|
465
|
+
layer=layer,
|
466
|
+
groupby=labels,
|
467
|
+
groups=genes,
|
468
|
+
reference=control,
|
469
|
+
method="t-test",
|
470
|
+
use_raw=False,
|
458
471
|
)
|
459
472
|
# get DE genes for each gene
|
460
473
|
for gene in genes:
|
@@ -469,15 +482,6 @@ class Mixscape:
|
|
469
482
|
return perturbation_markers
|
470
483
|
|
471
484
|
def _get_column_indices(self, adata, col_names):
|
472
|
-
"""Fetches the column indices in X for a given list of column names
|
473
|
-
|
474
|
-
Args:
|
475
|
-
adata: :class:`~anndata.AnnData` object
|
476
|
-
col_names: Column names to extract the indices for
|
477
|
-
|
478
|
-
Returns:
|
479
|
-
Set of column indices
|
480
|
-
"""
|
481
485
|
if isinstance(col_names, str): # pragma: no cover
|
482
486
|
col_names = [col_names]
|
483
487
|
|
@@ -501,3 +505,621 @@ class Mixscape:
|
|
501
505
|
sd = X.std()
|
502
506
|
|
503
507
|
return [mu, sd]
|
508
|
+
|
509
|
+
def plot_barplot(  # pragma: no cover
    self,
    adata: AnnData,
    guide_rna_column: str,
    mixscape_class_global: str = "mixscape_class_global",
    axis_text_x_size: int = 8,
    axis_text_y_size: int = 6,
    axis_title_size: int = 8,
    legend_title_size: int = 8,
    legend_text_size: int = 8,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
):
    """Barplot to visualize perturbation scores calculated by the `mixscape` function.

    Draws one stacked bar chart per target gene (five panels per row). Each bar is a guide RNA and
    shows the fraction of its cells that fall into each global mixscape class.

    Args:
        adata: The annotated data object.
        guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
            The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
            Labels are split at the last 'g' only, so gene names that contain a 'g' are kept intact.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        axis_text_x_size: Font size of the x-axis tick labels.
        axis_text_y_size: Font size of the y-axis tick labels.
        axis_title_size: Font size of the per-gene panel titles.
        legend_title_size: Font size of the legend title.
        legend_text_size: Font size of the legend entries.
        return_fig: If `True`, return the created figure.
        ax: Currently unused; the method always creates its own grid of axes.
        show: Show the plot, do not return axis.
        save: If True or a str, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {'.pdf', '.png', '.svg'}.

    Returns:
        The created :class:`~matplotlib.figure.Figure` if `return_fig` is `True`, otherwise `None`.

    Raises:
        ValueError: If `mixscape_class_global` is not a column of `adata.obs`.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_barplot.png
    """
    if mixscape_class_global not in adata.obs:
        raise ValueError("Please run the `mixscape` function first.")

    # Fraction of cells per (mixscape class, guide), normalised within each guide.
    count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
    all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
    KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
    KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)

    # Order the guides by decreasing KO fraction.
    new_levels = KO_cells_percentage[guide_rna_column]
    all_cells_percentage[guide_rna_column] = pd.Categorical(
        all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
    )
    all_cells_percentage[mixscape_class_global] = pd.Categorical(
        all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
    )

    # Split "<gene>g<#>" at the LAST "g" only; an unbounded split would break gene names containing "g".
    guide_parts = all_cells_percentage[guide_rna_column].str.rsplit("g", n=1, expand=True)
    all_cells_percentage["gene"] = guide_parts[0]
    all_cells_percentage["guide_number"] = "g" + guide_parts[1]
    NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]

    color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
    unique_genes = NP_KO_cells["gene"].unique()

    # Round the row count up so every gene gets a panel; squeeze=False keeps `axs` two-dimensional
    # even when a single row is enough.
    n_cols = 5
    n_rows = max(1, (len(unique_genes) + n_cols - 1) // n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 25), sharey=True, squeeze=False)

    gene_ax = None
    for i, gene in enumerate(unique_genes):
        gene_ax = axs[i // n_cols, i % n_cols]
        grouped_df = (
            NP_KO_cells[NP_KO_cells["gene"] == gene]
            .groupby(["guide_number", mixscape_class_global], observed=False)["value"]
            .sum()
            .unstack()
        )
        grouped_df.plot(
            kind="bar",
            stacked=True,
            color=[color_mapping[col] for col in grouped_df.columns],
            ax=gene_ax,
            width=0.8,
            legend=False,
        )
        gene_ax.set_title(
            gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
        )
        gene_ax.set(xlabel="sgRNA", ylabel="% of cells")
        sns.despine(ax=gene_ax, top=True, right=True, left=False, bottom=False)
        gene_ax.set_xticklabels(gene_ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
        gene_ax.set_yticklabels(gene_ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)

    # Hide the panels of the last row that have no gene assigned.
    for unused_ax in axs.flat[len(unique_genes) :]:
        unused_ax.set_visible(False)

    fig.subplots_adjust(right=0.8)
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    if gene_ax is not None:
        # A single shared legend, anchored relative to the last drawn panel.
        gene_ax.legend(
            title=mixscape_class_global,
            loc="center right",
            bbox_to_anchor=(2.2, 3.5),
            frameon=True,
            fontsize=legend_text_size,
            title_fontsize=legend_title_size,
        )

    plt.tight_layout()
    _utils.savefig_or_show("mixscape_barplot", show=show, save=save)

    if return_fig:
        return fig
    return None
|
608
|
+
|
609
|
+
def plot_heatmap(  # pragma: no cover
    self,
    adata: AnnData,
    labels: str,
    target_gene: str,
    control: str,
    layer: str | None = None,
    method: str | None = "wilcoxon",
    subsample_number: int | None = 900,
    vmin: float | None = -2,
    vmax: float | None = 2,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
) -> Axes | None:
    """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.

    The data is restricted to cells labelled `target_gene` or `control`, genes are ranked between these
    two groups, the subset is scaled and (optionally) subsampled, and the top genes of the control group
    are shown grouped by mixscape class.

    Args:
        adata: The annotated data object.
        labels: The column of `.obs` with target gene labels.
        target_gene: Target gene name to visualize heatmap for.
        control: Control category from the `pert_key` column.
        layer: Key from `adata.layers` whose value will be used to perform tests on.
        method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
        subsample_number: Subsample to this number of observations. `None` disables subsampling; values larger
            than the number of selected cells are capped at that number.
        vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
        vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
            Also passed to `scanpy.pp.scale` as `max_value`.
        return_fig: Forwarded unchanged to `scanpy.pl.rank_genes_groups_heatmap`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.

    Returns:
        The return value of `scanpy.pl.rank_genes_groups_heatmap`.

    Raises:
        ValueError: If `adata.obs` has no `mixscape_class` column.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_heatmap(
        ...     adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
        ... )

    Preview:
        .. image:: /_static/docstring_previews/mixscape_heatmap.png
    """
    if "mixscape_class" not in adata.obs:
        raise ValueError("Please run `pt.tl.mixscape` first.")

    selected = (adata.obs[labels] == target_gene) | (adata.obs[labels] == control)
    adata_subset = adata[selected].copy()
    sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
    sc.pp.scale(adata_subset, max_value=vmax)

    if subsample_number is not None:
        # Cap at the subset size so small subsets do not ask for more cells than exist.
        sc.pp.subsample(adata_subset, n_obs=min(subsample_number, adata_subset.n_obs))

    return sc.pl.rank_genes_groups_heatmap(
        adata_subset,
        groupby="mixscape_class",
        vmin=vmin,
        vmax=vmax,
        n_genes=20,
        # The ranking above was grouped by `labels`, so the control group is named `control`.
        groups=[control],
        return_fig=return_fig,
        show=show,
        save=save,
        **kwds,
    )
|
678
|
+
|
679
|
+
def plot_perturbscore(  # pragma: no cover
    self,
    adata: AnnData,
    labels: str,
    target_gene: str,
    mixscape_class: str = "mixscape_class",
    color: str = "orange",
    palette: dict[str, str] | None = None,
    split_by: str | None = None,
    before_mixscape: bool = False,
    perturbation_type: str = "KO",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Figure | Axes | None:
    """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.

    Requires `pt.tl.mixscape` to be run first.

    https://satijalab.org/seurat/reference/plotperturbscore

    Args:
        adata: The annotated data object.
        labels: The column of `.obs` with target gene labels.
        target_gene: Target gene name to visualize perturbation scores for.
        mixscape_class: The column of `.obs` with mixscape classifications.
        color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
        palette: Optional full color palette to overwrite all colors.
        split_by: Column used to draw one panel per category (for example biological replicates).
            If `None`, a single plot is created.
        before_mixscape: If `True`, split densities by the original target gene classification.
            If `False` (default), split densities by the mixscape classification.
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
        return_fig: If `True`, return the current figure.
        ax: Currently unused.
        show: If truthy, call `plt.show()`.
        save: If a `str`, the path the figure is saved to. If `True`, the figure is saved as
            'mixscape_perturbscore.png'.

    Returns:
        The current figure if `return_fig` is truthy; otherwise the current axes if neither `show` nor
        `save` is set; otherwise `None`.

    Raises:
        ValueError: If `adata.uns` has no `mixscape` entry.

    Examples:
        Visualizing the perturbation scores for the cells in a dataset:

        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_perturbscore.png
    """
    if "mixscape" not in adata.uns:
        raise ValueError("Please run the `mixscape` function first.")

    # `assign` returns new frames, so the results stored in `adata.uns` are not modified.
    perturbation_score = pd.concat(
        [scores.assign(name=key) for key, scores in adata.uns["mixscape"][target_gene].items()]
    )
    perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
    # The label in the score table that is not the target gene, i.e. the control label.
    gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]

    if before_mixscape:
        # Split densities based on the original target gene classification.
        hue_column = labels
        if palette is None:
            palette = {gd: "#7d7d7d", target_gene: color}
        above_axis = [gd]
        below_axis = [target_gene]
        kde_kwargs = {}
        title = "Density Plot"
        legend_title = labels
    else:
        # Split densities based on the mixscape classification.
        hue_column = "mix"
        perturbed = f"{target_gene} {perturbation_type}"
        non_perturbed = f"{target_gene} NP"
        if palette is None:
            palette = {gd: "#7d7d7d", non_perturbed: "#c9c9c9", perturbed: color}
        gd2 = list(set(perturbation_score["mix"]).difference([non_perturbed, perturbed]))[0]
        above_axis = [gd2]
        below_axis = [perturbed, non_perturbed]
        kde_kwargs = {"alpha": 0.7}
        title = "Density"
        legend_title = "mixscape class"

    # Throw-away density plot, only used to find the peak height that scales the rug jitter.
    density_ax = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
    top_r = max(line.get_data()[1].max() for line in density_ax.get_lines())
    plt.close()

    perturbation_score["y_jitter"] = perturbation_score["pvec"]
    rng = np.random.default_rng()
    # Control cells are jittered just above the x-axis, target-gene cells just below it.
    for group in above_axis:
        in_group = perturbation_score[hue_column] == group
        perturbation_score.loc[in_group, "y_jitter"] = rng.uniform(
            low=0.001, high=top_r / 10, size=in_group.sum()
        )
    for group in below_axis:
        in_group = perturbation_score[hue_column] == group
        perturbation_score.loc[in_group, "y_jitter"] = rng.uniform(low=-top_r / 10, high=0, size=in_group.sum())

    sns.set_theme(style="whitegrid")
    if split_by is not None:
        # One panel per `split_by` category, colored by the same classes as the single plot.
        g = sns.FacetGrid(
            data=perturbation_score, col=split_by, hue=hue_column, palette=palette, height=5, sharey=False
        )
        g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, **kde_kwargs)
        g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
        g.set_axis_labels("Perturbation score", "Cell density")
        g.add_legend(title=hue_column, fontsize=14, title_fontsize=16)
        g.despine(left=True)
    else:
        sns.kdeplot(
            data=perturbation_score,
            x="pvec",
            hue=hue_column,
            fill=True,
            common_norm=False,
            palette=palette,
            **kde_kwargs,
        )
        sns.scatterplot(
            data=perturbation_score, x="pvec", y="y_jitter", hue=hue_column, palette=palette, s=10, alpha=0.5
        )
        plt.xlabel("Perturbation score", fontsize=16)
        plt.ylabel("Cell density", fontsize=16)
        plt.title(title, fontsize=18)
        plt.legend(title=legend_title, title_fontsize=14, fontsize=12)
        sns.despine()

    if save:
        # `save` may be a plain `True`; fall back to a default file name in that case.
        plt.savefig("mixscape_perturbscore.png" if save is True else save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return plt.gca()
    return None
|
854
|
+
|
855
|
+
def plot_violin(  # pragma: no cover
    self,
    adata: AnnData,
    target_gene_idents: str | list[str],
    keys: str | Sequence[str] = "mixscape_class_p_ko",
    groupby: str | None = "mixscape_class",
    log: bool = False,
    use_raw: bool | None = None,
    stripplot: bool = True,
    hue: str | None = None,
    jitter: float | bool = True,
    size: int = 1,
    layer: str | None = None,
    scale: Literal["area", "count", "width"] = "width",
    order: Sequence[str] | None = None,
    multi_panel: bool | None = None,
    xlabel: str = "",
    ylabel: str | Sequence[str] | None = None,
    rotation: float | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
):
    """Violin plot using mixscape results.

    Requires `pt.tl.mixscape` to be run first.

    Args:
        adata: The annotated data object.
        target_gene_idents: Class label, or collection of class labels, of `.obs[groupby]` to plot.
            Only observations whose `.obs[groupby]` value matches one of them are kept.
        keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
            Duplicated keys are dropped, keeping the first occurrence.
        groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
        log: Plot on logarithmic axis.
        use_raw: Whether to use `raw` attribute of `adata`.
        stripplot: Add a stripplot on top of the violin plot.
        hue: Column used as `hue` for the violin plot and the stripplot.
            It is only fetched from `adata` when `groupby` is not `None`.
        jitter: Forwarded to `seaborn.stripplot` as `jitter`.
        size: Forwarded to `seaborn.stripplot` as `size`.
        layer: Forwarded to `scanpy.get.obs_df` as `layer`.
        scale: Forwarded to the seaborn violin plot as `scale`.
        order: Order in which to show the categories.
        multi_panel: If truthy and `groupby` is `None`, draw one panel per key with `seaborn.catplot`,
            each panel having its own y-axis.
        xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
        ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
            If `None` and `groupby` is not `None`, defaults to `keys`.
        rotation: Rotation of the x-axis tick labels.
        ax: A matplotlib axes object. Only works if plotting a single component.
        show: Show the plot, do not return axis. Falls back to `settings.autoshow` if `None`.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwargs: Additional arguments to `seaborn.violinplot`.

    Returns:
        `None` if the plot is shown. Otherwise the `seaborn.FacetGrid` of the multi-panel plot,
        a single :class:`~matplotlib.axes.Axes` if only one panel was drawn, or the collection of axes.

    Raises:
        ValueError: If the number of y-labels does not match the number of panels.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_violin(
        ...     adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
        ... )

    Preview:
        .. image:: /_static/docstring_previews/mixscape_violin.png
    """
    # Accept a single label or any iterable of labels. Previously only `str` and `list`
    # bound the mask, so e.g. a tuple failed with an unbound local variable.
    # NOTE(review): this indexes `.obs[groupby]`, so `groupby=None` fails here before the
    # `groupby is None` branches below are reached - confirm whether `None` should be supported.
    idents = [target_gene_idents] if isinstance(target_gene_idents, str) else list(target_gene_idents)
    adata = adata[adata.obs[groupby].isin(idents)]

    sanitize_anndata(adata)
    use_raw = _check_use_raw(adata, use_raw)
    if isinstance(keys, str):
        keys = [keys]
    keys = list(OrderedDict.fromkeys(keys))  # remove duplicates, preserving the order

    if isinstance(ylabel, str | type(None)):
        ylabel = [ylabel] * (1 if groupby is None else len(keys))
    if groupby is None:
        if len(ylabel) != 1:
            raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
    elif len(ylabel) != len(keys):
        raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")

    if groupby is not None:
        # Fetch the grouping column (and the hue column, if any) together with the plotted keys.
        columns = [groupby] + keys + [hue] if hue is not None else [groupby] + keys
        obs_df = get.obs_df(adata, keys=columns, layer=layer, use_raw=use_raw)
        obs_tidy = obs_df
        x = groupby
        ys = keys
    else:
        # Without a grouping, the keys themselves become the categories on the x-axis.
        obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
        obs_tidy = pd.melt(obs_df, value_vars=keys)
        x = "variable"
        ys = ["value"]

    use_catplot = multi_panel and groupby is None and len(ys) == 1

    # Legend entries are only collected in the single-axes branch; `None` marks "nothing collected"
    # so the legend call below cannot hit an undefined name in the catplot branch.
    handles = None
    labels = None

    if use_catplot:
        # This is a quick and dirty way for adapting scales across several
        # keys if groupby is None.
        y = ys[0]

        g = sns.catplot(
            y=y,
            data=obs_tidy,
            kind="violin",
            scale=scale,
            col=x,
            col_order=keys,
            sharey=False,
            order=keys,
            cut=0,
            inner=None,
            **kwargs,
        )

        if stripplot:
            grouped_df = obs_tidy.groupby(x)
            for ax_id, key in zip(range(g.axes.shape[1]), keys, strict=False):
                sns.stripplot(
                    y=y,
                    data=grouped_df.get_group(key),
                    jitter=jitter,
                    size=size,
                    color="black",
                    ax=g.axes[0, ax_id],
                )
        if log:
            g.set(yscale="log")
        g.set_titles(col_template="{col_name}").set_xlabels("")
        if rotation is not None:
            for ax in g.axes[0]:
                ax.tick_params(axis="x", labelrotation=rotation)
    else:
        # set by default the violin plot cut=0 to limit the extend
        # of the violin plot (see stacked_violin code) for more info.
        kwargs.setdefault("cut", 0)
        # `setdefault` without a value stores `inner=None` unless the caller passed `inner`.
        kwargs.setdefault("inner")

        if ax is None:
            axs, _, _, _ = _utils.setup_axes(
                ax=ax,
                panels=["x"] if groupby is None else keys,
                show_ticks=True,
                right_margin=0.3,
            )
        else:
            axs = [ax]
        for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
            ax = sns.violinplot(
                x=x,
                y=y,
                data=obs_tidy,
                order=order,
                orient="vertical",
                scale=scale,
                ax=ax,
                hue=hue,
                **kwargs,
            )
            # Get the handles and labels before the stripplot is drawn onto the same axes.
            handles, labels = ax.get_legend_handles_labels()
            if stripplot:
                ax = sns.stripplot(
                    x=x,
                    y=y,
                    data=obs_tidy,
                    order=order,
                    jitter=jitter,
                    color="black",
                    size=size,
                    ax=ax,
                    hue=hue,
                    dodge=True,
                )
            if xlabel == "" and groupby is not None and rotation is None:
                xlabel = groupby.replace("_", " ")
            ax.set_xlabel(xlabel)
            if ylab is not None:
                ax.set_ylabel(ylab)

            if log:
                ax.set_yscale("log")
            if rotation is not None:
                ax.tick_params(axis="x", labelrotation=rotation)

    show = settings.autoshow if show is None else show
    if hue is not None and stripplot is True and handles is not None:
        plt.legend(handles, labels)
    _utils.savefig_or_show("mixscape_violin", show=show, save=save)

    if not show:
        if use_catplot:
            return g
        elif len(axs) == 1:
            return axs[0]
        else:
            return axs
|
1057
|
+
|
1058
|
+
def plot_lda(  # pragma: no cover
    self,
    adata: AnnData,
    control: str,
    mixscape_class: str = "mixscape_class",
    mixscape_class_global: str = "mixscape_class_global",
    perturbation_type: str | None = "KO",
    lda_key: str | None = "mixscape_lda",
    n_components: int | None = None,
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
):
    """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.

    Args:
        adata: The annotated data object.
        control: Control category from the `pert_key` column.
        mixscape_class: The column of `.obs` with the mixscape classification result.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
        lda_key: Key of `.uns` that holds the LDA results. Default is 'mixscape_lda'.
        n_components: The number of dimensions of the embedding.
            Defaults to the number of columns of the stored LDA results.
        color_map: Forwarded to `scanpy.pl.umap` as `color_map`.
        palette: Forwarded to `scanpy.pl.umap` as `palette`.
        return_fig: Forwarded to `scanpy.pl.umap` as `return_fig`.
        ax: Forwarded to `scanpy.pl.umap` as `ax`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.umap`.

    Returns:
        Whatever `scanpy.pl.umap` returns for the given `return_fig` and `show` arguments.

    Raises:
        ValueError: If `mixscape_class` is missing from `.obs` or `lda_key` is missing from `.uns`.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_lda.png
    """
    # `!r` already quotes the key, so it must not be wrapped in a second pair of quotes.
    if mixscape_class not in adata.obs:
        raise ValueError(f"Did not find `.obs[{mixscape_class!r}]`. Please run the `mixscape` function first.")
    if lda_key not in adata.uns:
        raise ValueError(f"Did not find `.uns[{lda_key!r}]`. Please run the `lda` function first.")

    # Keep only the cells globally classified as the perturbation type or as control.
    global_class = adata.obs[mixscape_class_global]
    adata_subset = adata[(global_class == perturbation_type) | (global_class == control)].copy()

    # NOTE(review): assumes the rows of `.uns[lda_key]` line up with the cells of this subset - confirm
    # against the `lda` function that writes it.
    adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
    if n_components is None:
        n_components = adata_subset.uns[lda_key].shape[1]
    sc.pp.neighbors(adata_subset, use_rep=lda_key)
    sc.tl.umap(adata_subset, n_components=n_components)

    # Hand the scanpy result back; discarding it made `return_fig` and `show=False` useless to callers.
    return sc.pl.umap(
        adata_subset,
        color=mixscape_class,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        show=show,
        save=save,
        ax=ax,
        **kwds,
    )
|