pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py CHANGED
@@ -1,12 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
- import warnings
4
- from typing import TYPE_CHECKING
3
+ import copy
4
+ from collections import OrderedDict
5
+ from typing import TYPE_CHECKING, Literal
5
6
 
7
+ import matplotlib.pyplot as plt
6
8
  import numpy as np
7
9
  import pandas as pd
8
10
  import scanpy as sc
9
- from rich import print
11
+ import seaborn as sns
12
+ from scanpy import get
13
+ from scanpy._settings import settings
14
+ from scanpy._utils import _check_use_raw, sanitize_anndata
15
+ from scanpy.plotting import _utils
10
16
  from scanpy.tools._utils import _choose_representation
11
17
  from scipy.sparse import csr_matrix, issparse, spmatrix
12
18
  from sklearn.mixture import GaussianMixture
@@ -14,11 +20,13 @@ from sklearn.mixture import GaussianMixture
14
20
  import pertpy as pt
15
21
 
16
22
  if TYPE_CHECKING:
23
+ from collections.abc import Sequence
24
+
17
25
  from anndata import AnnData
26
+ from matplotlib.axes import Axes
27
+ from matplotlib.colors import Colormap
18
28
  from scipy import sparse
19
29
 
20
- warnings.simplefilter("ignore")
21
-
22
30
 
23
31
  class Mixscape:
24
32
  """Python implementation of Mixscape."""
@@ -65,15 +73,15 @@ class Mixscape:
65
73
 
66
74
  Returns:
67
75
  If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
68
- Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
76
+ Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
69
77
 
70
78
  Examples:
71
79
  Calcutate perturbation signature for each cell in the dataset:
72
80
 
73
81
  >>> import pertpy as pt
74
82
  >>> mdata = pt.dt.papalexi_2021()
75
- >>> mixscape_identifier = pt.tl.Mixscape()
76
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
83
+ >>> ms_pt = pt.tl.Mixscape()
84
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
77
85
  """
78
86
  if copy:
79
87
  adata = adata.copy()
@@ -86,18 +94,17 @@ class Mixscape:
86
94
  split_masks = [np.full(adata.n_obs, True, dtype=bool)]
87
95
  else:
88
96
  split_obs = adata.obs[split_by]
89
- cats = split_obs.unique()
90
- split_masks = [split_obs == cat for cat in cats]
97
+ split_masks = [split_obs == cat for cat in split_obs.unique()]
91
98
 
92
- R = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
99
+ representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
93
100
 
94
101
  for split_mask in split_masks:
95
102
  control_mask_split = control_mask & split_mask
96
103
 
97
- R_split = R[split_mask]
98
- R_control = R[control_mask_split]
104
+ R_split = representation[split_mask]
105
+ R_control = representation[control_mask_split]
99
106
 
100
- from pynndescent import NNDescent # saves a lot of import time
107
+ from pynndescent import NNDescent
101
108
 
102
109
  eps = kwargs.pop("epsilon", 0.1)
103
110
  nn_index = NNDescent(R_control, **kwargs)
@@ -161,7 +168,6 @@ class Mixscape:
161
168
 
162
169
  Args:
163
170
  adata: The annotated data object.
164
- pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
165
171
  labels: The column of `.obs` with target gene labels.
166
172
  control: Control category from the `pert_key` column.
167
173
  new_class_name: Name of mixscape classification to be stored in `.obs`.
@@ -177,26 +183,26 @@ class Mixscape:
177
183
 
178
184
  Returns:
179
185
  If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
180
- Otherwise writes the results directly to `.obs` of the provided `adata`.
186
+ Otherwise, writes the results directly to `.obs` of the provided `adata`.
181
187
 
182
- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
183
- Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
188
+ - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
189
+ Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
184
190
 
185
- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
186
- Global classification result (perturbed, NP or NT)
191
+ - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
192
+ Global classification result (perturbed, NP or NT).
187
193
 
188
- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
189
- Posterior probabilities used to determine if a cell is KO (default).
190
- Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP
194
+ - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
195
+ Posterior probabilities used to determine if a cell is KO (default).
196
+ Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP.
191
197
 
192
198
  Examples:
193
199
  Calcutate perturbation signature for each cell in the dataset:
194
200
 
195
201
  >>> import pertpy as pt
196
202
  >>> mdata = pt.dt.papalexi_2021()
197
- >>> mixscape_identifier = pt.tl.Mixscape()
198
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
199
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
203
+ >>> ms_pt = pt.tl.Mixscape()
204
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
205
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
200
206
  """
201
207
  if copy:
202
208
  adata = adata.copy()
@@ -220,10 +226,9 @@ class Mixscape:
220
226
  try:
221
227
  X = adata_comp.layers["X_pert"]
222
228
  except KeyError:
223
- print(
224
- '[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!'
225
- )
226
- raise
229
+ raise KeyError(
230
+ "No 'X_pert' found in .layers! Please run pert_sign first to calculate perturbation signature!"
231
+ ) from None
227
232
  # initialize return variables
228
233
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
229
234
  adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -305,9 +310,9 @@ class Mixscape:
305
310
  old_classes = adata.obs[new_class_name][all_cells]
306
311
  n_iter += 1
307
312
 
308
- adata.obs.loc[
309
- (adata.obs[new_class_name] == gene) & split_mask, new_class_name
310
- ] = f"{gene} {perturbation_type}"
313
+ adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
314
+ f"{gene} {perturbation_type}"
315
+ )
311
316
 
312
317
  adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
313
318
  adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
@@ -342,15 +347,17 @@ class Mixscape:
342
347
  control: Control category from the `pert_key` column. Defaults to 'NT'.
343
348
  n_comps: Number of principal components to use. Defaults to 10.
344
349
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
345
- logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25.
350
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
351
+ Defaults to 0.25.
346
352
  split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
347
353
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
348
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
354
+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
355
+ Defaults to KO.
349
356
  copy: Determines whether a copy of the `adata` is returned.
350
357
 
351
358
  Returns:
352
359
  If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
353
- Otherwise writes the results directly to `.uns` of the provided `adata`.
360
+ Otherwise, writes the results directly to `.uns` of the provided `adata`.
354
361
 
355
362
  mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
356
363
  LDA result.
@@ -360,10 +367,10 @@ class Mixscape:
360
367
 
361
368
  >>> import pertpy as pt
362
369
  >>> mdata = pt.dt.papalexi_2021()
363
- >>> mixscape_identifier = pt.tl.Mixscape()
364
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
365
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
366
- >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
370
+ >>> ms_pt = pt.tl.Mixscape()
371
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
372
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
373
+ >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
367
374
  """
368
375
  if copy:
369
376
  adata = adata.copy()
@@ -437,7 +444,7 @@ class Mixscape:
437
444
  min_de_genes: float,
438
445
  logfc_threshold: float,
439
446
  ) -> dict[tuple, np.ndarray]:
440
- """determine gene sets across all splits/groups through differential gene expression
447
+ """Determine gene sets across all splits/groups through differential gene expression
441
448
 
442
449
  Args:
443
450
  adata: :class:`~anndata.AnnData` object
@@ -469,15 +476,6 @@ class Mixscape:
469
476
  return perturbation_markers
470
477
 
471
478
  def _get_column_indices(self, adata, col_names):
472
- """Fetches the column indices in X for a given list of column names
473
-
474
- Args:
475
- adata: :class:`~anndata.AnnData` object
476
- col_names: Column names to extract the indices for
477
-
478
- Returns:
479
- Set of column indices
480
- """
481
479
  if isinstance(col_names, str): # pragma: no cover
482
480
  col_names = [col_names]
483
481
 
@@ -501,3 +499,623 @@ class Mixscape:
501
499
  sd = X.std()
502
500
 
503
501
  return [mu, sd]
502
+
503
def plot_barplot(  # pragma: no cover
    self,
    adata: AnnData,
    guide_rna_column: str,
    mixscape_class_global: str = "mixscape_class_global",
    axis_text_x_size: int = 8,
    axis_text_y_size: int = 6,
    axis_title_size: int = 8,
    legend_title_size: int = 8,
    legend_text_size: int = 8,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
):
    """Barplot to visualize perturbation scores calculated by the `mixscape` function.

    One panel is drawn per target gene (five panels per row). Each panel shows, for every guide of
    that gene, the fraction of its cells assigned to each global mixscape class (NT, NP, KO).
    Guides whose gene part is "NT" are not drawn.

    Args:
        adata: The annotated data object.
        guide_rna_column: The column of `.obs` with guide RNA labels.
            The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        axis_text_x_size: Font size of the x-axis tick labels.
        axis_text_y_size: Font size of the y-axis tick labels.
        axis_title_size: Font size of the per-gene panel titles.
        legend_title_size: Font size of the legend title.
        legend_text_size: Font size of the legend entries.
        return_fig: If `True`, return the :class:`~matplotlib.figure.Figure`.
        ax: Currently ignored; the plot always creates its own grid of axes.
        show: Show the plot, do not return axis.
        save: If True or a str, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {'.pdf', '.png', '.svg'}.

    Returns:
        The figure if `return_fig` is `True`. Otherwise, if the plot is not shown, a 2D numpy array of
        :class:`~matplotlib.axes.Axes` (one per panel). Otherwise `None`.

    Raises:
        ValueError: If `mixscape_class_global` is not a column of `.obs`, or if no non-"NT" guides are found.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_barplot.png
    """
    if mixscape_class_global not in adata.obs:
        raise ValueError("Please run the `mixscape` function first.")
    count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
    # Per-guide fraction of cells in each global class, reshaped to long format.
    all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
    KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
    KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)

    # Order guides by decreasing KO fraction.
    new_levels = KO_cells_percentage[guide_rna_column]
    all_cells_percentage[guide_rna_column] = pd.Categorical(
        all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
    )
    all_cells_percentage[mixscape_class_global] = pd.Categorical(
        all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
    )
    # Split only on the LAST "g" so gene names containing a lowercase "g" stay intact.
    guide_parts = all_cells_percentage[guide_rna_column].str.rsplit("g", n=1, expand=True)
    all_cells_percentage["gene"] = guide_parts[0]
    all_cells_percentage["guide_number"] = "g" + guide_parts[1]
    NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]

    unique_genes = NP_KO_cells["gene"].unique()
    if len(unique_genes) == 0:
        raise ValueError(f"No non-control guides found in `{guide_rna_column}`.")

    color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
    n_cols = 5
    # Round up so that every gene gets a panel, even with fewer than 5 genes or a partial last row.
    n_rows = (len(unique_genes) + n_cols - 1) // n_cols
    # squeeze=False keeps `axs` 2D even when there is only a single row.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 25), sharey=True, squeeze=False)
    panel_ax = None
    for i, gene in enumerate(unique_genes):
        panel_ax = axs[i // n_cols, i % n_cols]
        grouped_df = (
            NP_KO_cells[NP_KO_cells["gene"] == gene]
            .groupby(["guide_number", mixscape_class_global], observed=False)["value"]
            .sum()
            .unstack()
        )
        grouped_df.plot(
            kind="bar",
            stacked=True,
            color=[color_mapping[col] for col in grouped_df.columns],
            ax=panel_ax,
            width=0.8,
            legend=False,
        )
        panel_ax.set_title(
            gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
        )
        panel_ax.set(xlabel="sgRNA", ylabel="% of cells")
        sns.despine(ax=panel_ax, top=True, right=True, left=False, bottom=False)
        plt.setp(panel_ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
        # tick_params styles the y labels without freezing them the way set_yticklabels would.
        panel_ax.tick_params(axis="y", labelrotation=0, labelsize=axis_text_y_size)

    # Hide the unused panels of an incomplete last row.
    for unused_ax in axs.flat[len(unique_genes) :]:
        unused_ax.set_visible(False)

    fig.subplots_adjust(right=0.8)
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    panel_ax.legend(
        title="mixscape_class_global",
        loc="center right",
        bbox_to_anchor=(2.2, 3.5),
        frameon=True,
        fontsize=legend_text_size,
        title_fontsize=legend_title_size,
    )

    plt.tight_layout()
    # Resolve `show` the same way scanpy's savefig_or_show does, to decide what to return.
    resolved_show = settings.autoshow if show is None else show
    _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
    if return_fig:
        return fig
    if not resolved_show:
        return axs
    return None
602
+
603
def plot_heatmap(  # pragma: no cover
    self,
    adata: AnnData,
    labels: str,
    target_gene: str,
    control: str,
    layer: str | None = None,
    method: str | None = "wilcoxon",
    subsample_number: int | None = 900,
    vmin: float | None = -2,
    vmax: float | None = 2,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
) -> Axes | None:
    """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.

    Cells labelled `target_gene` or `control` in `labels` are selected, and genes are ranked between these
    label groups. The data are scaled and optionally subsampled. The top 20 ranked genes of the `control`
    group are then shown, with cells grouped by `mixscape_class`.

    Args:
        adata: The annotated data object.
        labels: The column of `.obs` with target gene labels.
        target_gene: Target gene name to visualize heatmap for.
        control: Control category from the `labels` column.
        layer: Key from `adata.layers` whose value will be used to perform tests on.
        method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
        subsample_number: Subsample to this number of observations. No subsampling is done if `None`
            or if there are not more cells than this.
        vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
        vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
            Also used as `max_value` when scaling the data.
        return_fig: Forwarded to `scanpy.pl.rank_genes_groups_heatmap`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.

    Returns:
        If `show==False`, return a :class:`~matplotlib.axes.Axes`.

    Raises:
        ValueError: If `mixscape_class` is not a column of `.obs`.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_heatmap(
        ...     adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
        ... )

    Preview:
        .. image:: /_static/docstring_previews/mixscape_heatmap.png
    """
    if "mixscape_class" not in adata.obs:
        raise ValueError("Please run `pt.tl.mixscape` first.")
    adata_subset = adata[adata.obs[labels].isin([target_gene, control])].copy()
    sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
    sc.pp.scale(adata_subset, max_value=vmax)
    # sc.pp.subsample samples without replacement, so it cannot draw more cells than exist.
    if subsample_number is not None and subsample_number < adata_subset.n_obs:
        sc.pp.subsample(adata_subset, n_obs=subsample_number)

    return sc.pl.rank_genes_groups_heatmap(
        adata_subset,
        groupby="mixscape_class",
        vmin=vmin,
        vmax=vmax,
        n_genes=20,
        # Use the caller's control category rather than a hard-coded "NT".
        groups=[control],
        return_fig=return_fig,
        show=show,
        save=save,
        **kwds,
    )
672
+
673
def plot_perturbscore(  # pragma: no cover
    self,
    adata: AnnData,
    labels: str,
    target_gene: str,
    mixscape_class: str = "mixscape_class",
    color: str = "orange",
    palette: dict[str, str] | None = None,
    split_by: str | None = None,
    before_mixscape: bool = False,
    perturbation_type: str = "KO",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Figure | Axes | None:
    """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.

    Requires `pt.tl.mixscape` to be run first.

    https://satijalab.org/seurat/reference/plotperturbscore

    Args:
        adata: The annotated data object.
        labels: The column of `.obs` with target gene labels.
        target_gene: Target gene name to visualize perturbation scores for.
        mixscape_class: The column of `.obs` with mixscape classifications.
        color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
        palette: Optional full color palette to overwrite all colors.
        split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
            the perturbation signature for every replicate separately.
        before_mixscape: If `True`, split densities by the original labels in `labels`.
            If `False` (default), split them by the mixscape classification.
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
            Defaults to `KO`.
        return_fig: If `True`, return the figure.
        ax: Axes to draw on when `split_by` is `None`. A new figure is created if not given.
            Ignored when `split_by` is set, because a facet grid creates its own figure.
        show: If truthy, call `plt.show()`.
        save: If truthy, passed as the file name to `Figure.savefig`.

    Returns:
        The figure if `return_fig` is truthy. Otherwise, if neither `show` nor `save` is set, the axes.
        Otherwise `None`.

    Raises:
        ValueError: If `adata.uns` has no `mixscape` entry.
        KeyError: If no perturbation scores are stored for `target_gene`.

    Examples:
        Visualizing the perturbation scores for the cells in a dataset:

        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_perturbscore.png
    """
    if "mixscape" not in adata.uns:
        raise ValueError("Please run the `mixscape` function first.")
    if target_gene not in adata.uns["mixscape"]:
        raise KeyError(f"No perturbation scores found for target gene '{target_gene}' in `adata.uns['mixscape']`.")

    # Stack the stored score tables, tagging each with its key. Copies keep `adata.uns` unmodified.
    score_tables = []
    for key, stored_scores in adata.uns["mixscape"][target_gene].items():
        scores = stored_scores.copy()
        scores["name"] = key
        score_tables.append(scores)
    perturbation_score = pd.concat(score_tables)
    perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
    # First label among the scored cells that is not `target_gene`.
    gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]

    # Peak density over all label groups, drawn on a throwaway figure. It bounds the vertical jitter below.
    probe_fig, probe_ax = plt.subplots()
    sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False, ax=probe_ax)
    top_r = max(line.get_data()[1].max() for line in probe_ax.get_lines())
    plt.close(probe_fig)

    perturbation_score["y_jitter"] = perturbation_score["pvec"]
    rng = np.random.default_rng()

    if before_mixscape:
        # Split densities by the original target gene labels.
        if palette is None:
            palette = {gd: "#7d7d7d", target_gene: color}
        hue = labels
        above = perturbation_score[labels] == gd
        below = [perturbation_score[labels] == target_gene]
    else:
        # Split densities by the mixscape classifications.
        perturbed_class = f"{target_gene} {perturbation_type}"
        if palette is None:
            palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", perturbed_class: color}
        hue = "mix"
        gd2 = list(set(perturbation_score["mix"]).difference([f"{target_gene} NP", perturbed_class]))[0]
        above = perturbation_score["mix"] == gd2
        below = [perturbation_score["mix"] == perturbed_class, perturbation_score["mix"] == f"{target_gene} NP"]

    # Reference-group points are jittered just above zero, all other groups just below it.
    perturbation_score.loc[above, "y_jitter"] = rng.uniform(low=0.001, high=top_r / 10, size=above.sum())
    for mask in below:
        perturbation_score.loc[mask, "y_jitter"] = rng.uniform(low=-top_r / 10, high=0, size=mask.sum())

    kde_kws = {} if before_mixscape else {"alpha": 0.7}
    sns.set_theme(style="whitegrid")
    if split_by is not None:
        # One facet per value of `split_by`, coloured by the same grouping as the densities.
        g = sns.FacetGrid(data=perturbation_score, col=split_by, hue=hue, palette=palette, height=5, sharey=False)
        g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, **kde_kws)
        g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
        g.set_axis_labels("Perturbation score", "Cell density")
        g.add_legend(title=hue, fontsize=14, title_fontsize=16)
        g.despine(left=True)
        fig, ax = plt.gcf(), plt.gca()
    else:
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        sns.kdeplot(
            data=perturbation_score,
            x="pvec",
            hue=hue,
            fill=True,
            common_norm=False,
            palette=palette,
            ax=ax,
            **kde_kws,
        )
        sns.scatterplot(
            data=perturbation_score, x="pvec", y="y_jitter", hue=hue, palette=palette, s=10, alpha=0.5, ax=ax
        )
        ax.set_xlabel("Perturbation score", fontsize=16)
        ax.set_ylabel("Cell density", fontsize=16)
        ax.set_title("Density Plot" if before_mixscape else "Density", fontsize=18)
        ax.legend(title=labels if before_mixscape else "mixscape class", title_fontsize=14, fontsize=12)
        sns.despine(ax=ax)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
849
+
850
+ def plot_violin( # pragma: no cover
851
+ self,
852
+ adata: AnnData,
853
+ target_gene_idents: str | list[str],
854
+ keys: str | Sequence[str] = "mixscape_class_p_ko",
855
+ groupby: str | None = "mixscape_class",
856
+ log: bool = False,
857
+ use_raw: bool | None = None,
858
+ stripplot: bool = True,
859
+ hue: str | None = None,
860
+ jitter: float | bool = True,
861
+ size: int = 1,
862
+ layer: str | None = None,
863
+ scale: Literal["area", "count", "width"] = "width",
864
+ order: Sequence[str] | None = None,
865
+ multi_panel: bool | None = None,
866
+ xlabel: str = "",
867
+ ylabel: str | Sequence[str] | None = None,
868
+ rotation: float | None = None,
869
+ ax: Axes | None = None,
870
+ show: bool | None = None,
871
+ save: bool | str | None = None,
872
+ **kwargs,
873
+ ):
874
+ """Violin plot using mixscape results.
875
+
876
+ Requires `pt.tl.mixscape` to be run first.
877
+
878
+ Args:
879
+ adata: The annotated data object.
880
+ target_gene_idents: Target gene name to plot.
881
+ keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
882
+ groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
883
+ log: Plot on logarithmic axis.
884
+ use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
885
+ stripplot: Add a stripplot on top of the violin plot.
886
+ order: Order in which to show the categories.
887
+ xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
888
+ ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
889
+ If `None` and `groubpy` is not `None`, defaults to `keys`.
890
+ show: Show the plot, do not return axis.
891
+ save: If `True` or a `str`, save the figure. A string is appended to the default filename.
892
+ Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
893
+ ax: A matplotlib axes object. Only works if plotting a single component.
894
+ **kwargs: Additional arguments to `seaborn.violinplot`.
895
+
896
+ Returns:
897
+ A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
898
+
899
+ Examples:
900
+ >>> import pertpy as pt
901
+ >>> mdata = pt.dt.papalexi_2021()
902
+ >>> ms_pt = pt.tl.Mixscape()
903
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
904
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
905
+ >>> ms_pt.plot_violin(
906
+ ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
907
+ ... )
908
+
909
+ Preview:
910
+ .. image:: /_static/docstring_previews/mixscape_violin.png
911
+ """
912
+ if isinstance(target_gene_idents, str):
913
+ mixscape_class_mask = adata.obs[groupby] == target_gene_idents
914
+ elif isinstance(target_gene_idents, list):
915
+ mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
916
+ for ident in target_gene_idents:
917
+ mixscape_class_mask |= adata.obs[groupby] == ident
918
+ adata = adata[mixscape_class_mask]
919
+
920
+ sanitize_anndata(adata)
921
+ use_raw = _check_use_raw(adata, use_raw)
922
+ if isinstance(keys, str):
923
+ keys = [keys]
924
+ keys = list(OrderedDict.fromkeys(keys)) # remove duplicates, preserving the order
925
+
926
+ if isinstance(ylabel, str | type(None)):
927
+ ylabel = [ylabel] * (1 if groupby is None else len(keys))
928
+ if groupby is None:
929
+ if len(ylabel) != 1:
930
+ raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
931
+ elif len(ylabel) != len(keys):
932
+ raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
933
+
934
+ if groupby is not None:
935
+ if hue is not None:
936
+ obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
937
+ else:
938
+ obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
939
+
940
+ else:
941
+ obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
942
+ if groupby is None:
943
+ obs_tidy = pd.melt(obs_df, value_vars=keys)
944
+ x = "variable"
945
+ ys = ["value"]
946
+ else:
947
+ obs_tidy = obs_df
948
+ x = groupby
949
+ ys = keys
950
+
951
+ if multi_panel and groupby is None and len(ys) == 1:
952
+ # This is a quick and dirty way for adapting scales across several
953
+ # keys if groupby is None.
954
+ y = ys[0]
955
+
956
+ g = sns.catplot(
957
+ y=y,
958
+ data=obs_tidy,
959
+ kind="violin",
960
+ scale=scale,
961
+ col=x,
962
+ col_order=keys,
963
+ sharey=False,
964
+ order=keys,
965
+ cut=0,
966
+ inner=None,
967
+ **kwargs,
968
+ )
969
+
970
+ if stripplot:
971
+ grouped_df = obs_tidy.groupby(x)
972
+ for ax_id, key in zip(range(g.axes.shape[1]), keys, strict=False):
973
+ sns.stripplot(
974
+ y=y,
975
+ data=grouped_df.get_group(key),
976
+ jitter=jitter,
977
+ size=size,
978
+ color="black",
979
+ ax=g.axes[0, ax_id],
980
+ )
981
+ if log:
982
+ g.set(yscale="log")
983
+ g.set_titles(col_template="{col_name}").set_xlabels("")
984
+ if rotation is not None:
985
+ for ax in g.axes[0]:
986
+ ax.tick_params(axis="x", labelrotation=rotation)
987
+ else:
988
+ # set by default the violin plot cut=0 to limit the extend
989
+ # of the violin plot (see stacked_violin code) for more info.
990
+ kwargs.setdefault("cut", 0)
991
+ kwargs.setdefault("inner")
992
+
993
+ if ax is None:
994
+ axs, _, _, _ = _utils.setup_axes(
995
+ ax=ax,
996
+ panels=["x"] if groupby is None else keys,
997
+ show_ticks=True,
998
+ right_margin=0.3,
999
+ )
1000
+ else:
1001
+ axs = [ax]
1002
+ for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
1003
+ ax = sns.violinplot(
1004
+ x=x,
1005
+ y=y,
1006
+ data=obs_tidy,
1007
+ order=order,
1008
+ orient="vertical",
1009
+ scale=scale,
1010
+ ax=ax,
1011
+ hue=hue,
1012
+ **kwargs,
1013
+ )
1014
+ # Get the handles and labels.
1015
+ handles, labels = ax.get_legend_handles_labels()
1016
+ if stripplot:
1017
+ ax = sns.stripplot(
1018
+ x=x,
1019
+ y=y,
1020
+ data=obs_tidy,
1021
+ order=order,
1022
+ jitter=jitter,
1023
+ color="black",
1024
+ size=size,
1025
+ ax=ax,
1026
+ hue=hue,
1027
+ dodge=True,
1028
+ )
1029
+ if xlabel == "" and groupby is not None and rotation is None:
1030
+ xlabel = groupby.replace("_", " ")
1031
+ ax.set_xlabel(xlabel)
1032
+ if ylab is not None:
1033
+ ax.set_ylabel(ylab)
1034
+
1035
+ if log:
1036
+ ax.set_yscale("log")
1037
+ if rotation is not None:
1038
+ ax.tick_params(axis="x", labelrotation=rotation)
1039
+
1040
+ show = settings.autoshow if show is None else show
1041
+ if hue is not None and stripplot is True:
1042
+ plt.legend(handles, labels)
1043
+ _utils.savefig_or_show("mixscape_violin", show=show, save=save)
1044
+
1045
+ if not show:
1046
+ if multi_panel and groupby is None and len(ys) == 1:
1047
+ return g
1048
+ elif len(axs) == 1:
1049
+ return axs[0]
1050
+ else:
1051
+ return axs
1052
+
1053
def plot_lda(  # pragma: no cover
    self,
    adata: AnnData,
    control: str,
    mixscape_class: str = "mixscape_class",
    mixscape_class_global: str = "mixscape_class_global",
    perturbation_type: str | None = "KO",
    lda_key: str | None = "mixscape_lda",
    n_components: int | None = None,
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
) -> plt.Figure | Axes | list[Axes] | None:
    """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.

    Cells whose `mixscape_class_global` label equals `perturbation_type` or `control` are selected.
    The LDA projection stored in `.uns[lda_key]` is attached to them as an embedding in `.obsm[lda_key]`.
    A neighbor graph and UMAP are computed on that embedding, and the UMAP is plotted colored by `mixscape_class`.
    The input `adata` is not modified; all computation happens on a copy of the selected cells.

    Args:
        adata: The annotated data object.
        control: Control category in the `mixscape_class_global` column.
        mixscape_class: The column of `.obs` with the mixscape classification result.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
            Defaults to 'KO'.
        lda_key: Key of `.uns` holding the LDA results. If `None`, `.uns["mixscape_lda"]` is used.
            The stored array must have one row per selected cell.
        n_components: The number of dimensions of the UMAP embedding. Defaults to the number of LDA components.
        color_map: Forwarded to `scanpy.pl.umap`.
        palette: Forwarded to `scanpy.pl.umap`.
        return_fig: Forwarded to `scanpy.pl.umap`; if `True`, the figure is returned.
        ax: A matplotlib axes object to draw into. Forwarded to `scanpy.pl.umap`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.umap`.

    Returns:
        The return value of `scanpy.pl.umap`: the figure when `return_fig` is `True`,
        the axes when the plot is not shown, otherwise `None`.

    Raises:
        ValueError: If `mixscape_class` or `mixscape_class_global` is missing from `.obs`,
            if `lda_key` is missing from `.uns`,
            or if the LDA results do not have one row per selected cell.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_lda.png
    """
    # `lda_key` is typed as optional; fall back to the documented default instead of failing on `None`.
    if lda_key is None:
        lda_key = "mixscape_lda"

    for column in (mixscape_class, mixscape_class_global):
        if column not in adata.obs:
            raise ValueError(f'Did not find `.obs["{column}"]`. Please run the `mixscape` function first.')
    if lda_key not in adata.uns:
        raise ValueError(f'Did not find `.uns["{lda_key}"]`. Please run the `lda` function first.')

    selected = (adata.obs[mixscape_class_global] == perturbation_type) | (
        adata.obs[mixscape_class_global] == control
    )
    adata_subset = adata[selected].copy()

    # The LDA projection lives in `.uns`, not `.obsm`, so nothing ties its rows to cells.
    # NOTE(review): this assumes `lda` was run with the same `control`/`perturbation_type` selection.
    # Fail early with a clear message if the row count does not match.
    lda_embedding = adata_subset.uns[lda_key]
    if lda_embedding.shape[0] != adata_subset.n_obs:
        raise ValueError(
            f'`.uns["{lda_key}"]` has {lda_embedding.shape[0]} rows, but {adata_subset.n_obs} cells have '
            f'`.obs["{mixscape_class_global}"]` equal to {perturbation_type!r} or {control!r}. '
            "Please re-run the `lda` function with the same settings."
        )
    # Register the projection as an embedding so `neighbors` can use it as the representation.
    adata_subset.obsm[lda_key] = lda_embedding

    if n_components is None:
        n_components = lda_embedding.shape[1]
    sc.pp.neighbors(adata_subset, use_rep=lda_key)
    sc.tl.umap(adata_subset, n_components=n_components)

    # Return scanpy's result so that `return_fig=True` / `show=False` actually hand back the figure / axes.
    return sc.pl.umap(
        adata_subset,
        color=mixscape_class,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        show=show,
        save=save,
        ax=ax,
        **kwds,
    )