pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py CHANGED
@@ -1,12 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
- import warnings
4
- from typing import TYPE_CHECKING
3
+ import copy
4
+ from collections import OrderedDict
5
+ from typing import TYPE_CHECKING, Literal
5
6
 
7
+ import matplotlib.pyplot as plt
6
8
  import numpy as np
7
9
  import pandas as pd
8
10
  import scanpy as sc
9
- from rich import print
11
+ import seaborn as sns
12
+ from scanpy import get
13
+ from scanpy._settings import settings
14
+ from scanpy._utils import _check_use_raw, sanitize_anndata
15
+ from scanpy.plotting import _utils
10
16
  from scanpy.tools._utils import _choose_representation
11
17
  from scipy.sparse import csr_matrix, issparse, spmatrix
12
18
  from sklearn.mixture import GaussianMixture
@@ -14,11 +20,13 @@ from sklearn.mixture import GaussianMixture
14
20
  import pertpy as pt
15
21
 
16
22
  if TYPE_CHECKING:
23
+ from collections.abc import Sequence
24
+
17
25
  from anndata import AnnData
26
+ from matplotlib.axes import Axes
27
+ from matplotlib.colors import Colormap
18
28
  from scipy import sparse
19
29
 
20
- warnings.simplefilter("ignore")
21
-
22
30
 
23
31
  class Mixscape:
24
32
  """Python implementation of Mixscape."""
@@ -65,15 +73,15 @@ class Mixscape:
65
73
 
66
74
  Returns:
67
75
  If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
68
- Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
76
+ Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
69
77
 
70
78
  Examples:
71
79
  Calcutate perturbation signature for each cell in the dataset:
72
80
 
73
81
  >>> import pertpy as pt
74
82
  >>> mdata = pt.dt.papalexi_2021()
75
- >>> mixscape_identifier = pt.tl.Mixscape()
76
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
83
+ >>> ms_pt = pt.tl.Mixscape()
84
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
77
85
  """
78
86
  if copy:
79
87
  adata = adata.copy()
@@ -86,18 +94,17 @@ class Mixscape:
86
94
  split_masks = [np.full(adata.n_obs, True, dtype=bool)]
87
95
  else:
88
96
  split_obs = adata.obs[split_by]
89
- cats = split_obs.unique()
90
- split_masks = [split_obs == cat for cat in cats]
97
+ split_masks = [split_obs == cat for cat in split_obs.unique()]
91
98
 
92
- R = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
99
+ representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
93
100
 
94
101
  for split_mask in split_masks:
95
102
  control_mask_split = control_mask & split_mask
96
103
 
97
- R_split = R[split_mask]
98
- R_control = R[control_mask_split]
104
+ R_split = representation[split_mask]
105
+ R_control = representation[control_mask_split]
99
106
 
100
- from pynndescent import NNDescent # saves a lot of import time
107
+ from pynndescent import NNDescent
101
108
 
102
109
  eps = kwargs.pop("epsilon", 0.1)
103
110
  nn_index = NNDescent(R_control, **kwargs)
@@ -161,7 +168,6 @@ class Mixscape:
161
168
 
162
169
  Args:
163
170
  adata: The annotated data object.
164
- pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
165
171
  labels: The column of `.obs` with target gene labels.
166
172
  control: Control category from the `pert_key` column.
167
173
  new_class_name: Name of mixscape classification to be stored in `.obs`.
@@ -177,26 +183,26 @@ class Mixscape:
177
183
 
178
184
  Returns:
179
185
  If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
180
- Otherwise writes the results directly to `.obs` of the provided `adata`.
186
+ Otherwise, writes the results directly to `.obs` of the provided `adata`.
181
187
 
182
- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
183
- Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
188
+ - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
189
+ Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
184
190
 
185
- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
186
- Global classification result (perturbed, NP or NT)
191
+ - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
192
+ Global classification result (perturbed, NP or NT).
187
193
 
188
- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
189
- Posterior probabilities used to determine if a cell is KO (default).
190
- Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP
194
+ - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
195
+ Posterior probabilities used to determine if a cell is KO (default).
196
+ Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP.
191
197
 
192
198
  Examples:
193
199
  Calcutate perturbation signature for each cell in the dataset:
194
200
 
195
201
  >>> import pertpy as pt
196
202
  >>> mdata = pt.dt.papalexi_2021()
197
- >>> mixscape_identifier = pt.tl.Mixscape()
198
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
199
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
203
+ >>> ms_pt = pt.tl.Mixscape()
204
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
205
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
200
206
  """
201
207
  if copy:
202
208
  adata = adata.copy()
@@ -220,10 +226,9 @@ class Mixscape:
220
226
  try:
221
227
  X = adata_comp.layers["X_pert"]
222
228
  except KeyError:
223
- print(
224
- '[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!'
225
- )
226
- raise
229
+ raise KeyError(
230
+ "No 'X_pert' found in .layers! Please run pert_sign first to calculate perturbation signature!"
231
+ ) from None
227
232
  # initialize return variables
228
233
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
229
234
  adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -305,9 +310,9 @@ class Mixscape:
305
310
  old_classes = adata.obs[new_class_name][all_cells]
306
311
  n_iter += 1
307
312
 
308
- adata.obs.loc[
309
- (adata.obs[new_class_name] == gene) & split_mask, new_class_name
310
- ] = f"{gene} {perturbation_type}"
313
+ adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
314
+ f"{gene} {perturbation_type}"
315
+ )
311
316
 
312
317
  adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
313
318
  adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
@@ -342,15 +347,17 @@ class Mixscape:
342
347
  control: Control category from the `pert_key` column. Defaults to 'NT'.
343
348
  n_comps: Number of principal components to use. Defaults to 10.
344
349
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
345
- logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25.
350
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
351
+ Defaults to 0.25.
346
352
  split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
347
353
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
348
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
354
+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
355
+ Defaults to KO.
349
356
  copy: Determines whether a copy of the `adata` is returned.
350
357
 
351
358
  Returns:
352
359
  If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
353
- Otherwise writes the results directly to `.uns` of the provided `adata`.
360
+ Otherwise, writes the results directly to `.uns` of the provided `adata`.
354
361
 
355
362
  mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
356
363
  LDA result.
@@ -360,10 +367,10 @@ class Mixscape:
360
367
 
361
368
  >>> import pertpy as pt
362
369
  >>> mdata = pt.dt.papalexi_2021()
363
- >>> mixscape_identifier = pt.tl.Mixscape()
364
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
365
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
366
- >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
370
+ >>> ms_pt = pt.tl.Mixscape()
371
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
372
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
373
+ >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
367
374
  """
368
375
  if copy:
369
376
  adata = adata.copy()
@@ -437,7 +444,7 @@ class Mixscape:
437
444
  min_de_genes: float,
438
445
  logfc_threshold: float,
439
446
  ) -> dict[tuple, np.ndarray]:
440
- """determine gene sets across all splits/groups through differential gene expression
447
+ """Determine gene sets across all splits/groups through differential gene expression
441
448
 
442
449
  Args:
443
450
  adata: :class:`~anndata.AnnData` object
@@ -469,15 +476,6 @@ class Mixscape:
469
476
  return perturbation_markers
470
477
 
471
478
  def _get_column_indices(self, adata, col_names):
472
- """Fetches the column indices in X for a given list of column names
473
-
474
- Args:
475
- adata: :class:`~anndata.AnnData` object
476
- col_names: Column names to extract the indices for
477
-
478
- Returns:
479
- Set of column indices
480
- """
481
479
  if isinstance(col_names, str): # pragma: no cover
482
480
  col_names = [col_names]
483
481
 
@@ -501,3 +499,623 @@ class Mixscape:
501
499
  sd = X.std()
502
500
 
503
501
  return [mu, sd]
502
+
503
def plot_barplot(  # pragma: no cover
    self,
    adata: AnnData,
    guide_rna_column: str,
    mixscape_class_global: str = "mixscape_class_global",
    axis_text_x_size: int = 8,
    axis_text_y_size: int = 6,
    axis_title_size: int = 8,
    legend_title_size: int = 8,
    legend_text_size: int = 8,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
):
    """Barplot to visualize perturbation scores calculated by the `mixscape` function.

    One panel is drawn per target gene (five panels per row). Each panel shows, for every guide of
    that gene, the stacked fraction of cells falling into each global mixscape class.

    Args:
        adata: The annotated data object.
        guide_rna_column: The column of `.obs` with guide RNA labels.
            The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        axis_text_x_size: Font size of the x-axis tick labels.
        axis_text_y_size: Font size of the y-axis tick labels.
        axis_title_size: Font size of the per-gene panel titles.
        legend_title_size: Font size of the legend title.
        legend_text_size: Font size of the legend entries.
        return_fig: If `True`, return the :class:`~matplotlib.figure.Figure`.
        ax: Currently unused; a new figure with one panel per gene is always created.
        show: Show the plot, do not return axis.
        save: If True or a str, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {'.pdf', '.png', '.svg'}.

    Returns:
        The figure if `return_fig` is `True`; otherwise, if `show` is `False`, the 2D array of panel
        :class:`~matplotlib.axes.Axes`; else `None`.

    Raises:
        ValueError: If `mixscape_class_global` is missing from `.obs`, or no targeting guides are found.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="guide_ID")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_barplot.png
    """
    if mixscape_class_global not in adata.obs:
        raise ValueError("Please run the `mixscape` function first.")
    count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
    # Long format: one row per (class, guide) with the fraction of that guide's cells in the class.
    all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
    KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
    KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)

    new_levels = KO_cells_percentage[guide_rna_column]
    all_cells_percentage[guide_rna_column] = pd.Categorical(
        all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
    )
    all_cells_percentage[mixscape_class_global] = pd.Categorical(
        all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
    )
    # Split only on the LAST "g" so gene symbols containing a lowercase "g" stay intact.
    guide_parts = all_cells_percentage[guide_rna_column].str.rsplit("g", n=1, expand=True)
    all_cells_percentage["gene"] = guide_parts[0]
    all_cells_percentage["guide_number"] = "g" + guide_parts[1]
    NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]

    color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
    unique_genes = NP_KO_cells["gene"].unique()
    if len(unique_genes) == 0:
        raise ValueError(f"No targeting guides found in `.obs['{guide_rna_column}']`.")

    n_cols = 5
    # Ceiling division so a partially filled last row still gets a row of panels.
    n_rows = (len(unique_genes) + n_cols - 1) // n_cols
    # squeeze=False keeps `axs` 2D even when there is a single row.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 25), sharey=True, squeeze=False)
    panel_ax = None
    for i, gene in enumerate(unique_genes):
        panel_ax = axs[i // n_cols, i % n_cols]
        grouped_df = (
            NP_KO_cells[NP_KO_cells["gene"] == gene]
            .groupby(["guide_number", mixscape_class_global], observed=False)["value"]
            .sum()
            .unstack()
        )
        grouped_df.plot(
            kind="bar",
            stacked=True,
            color=[color_mapping[col] for col in grouped_df.columns],
            ax=panel_ax,
            width=0.8,
            legend=False,
        )
        panel_ax.set_title(
            gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
        )
        panel_ax.set(xlabel="sgRNA", ylabel="% of cells")
        sns.despine(ax=panel_ax, top=True, right=True, left=False, bottom=False)
        panel_ax.set_xticklabels(panel_ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
        panel_ax.set_yticklabels(panel_ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)

    # Hide the unused panels at the end of the last row.
    for unused_ax in axs.flat[len(unique_genes) :]:
        unused_ax.set_visible(False)

    fig.subplots_adjust(right=0.8)
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    panel_ax.legend(
        title=mixscape_class_global,
        loc="center right",
        bbox_to_anchor=(2.2, 3.5),
        frameon=True,
        fontsize=legend_text_size,
        title_fontsize=legend_title_size,
    )

    plt.tight_layout()
    _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
    if return_fig:
        return fig
    if show is False:
        return axs
    return None
602
+
603
def plot_heatmap(  # pragma: no cover
    self,
    adata: AnnData,
    labels: str,
    target_gene: str,
    control: str,
    layer: str | None = None,
    method: str | None = "wilcoxon",
    subsample_number: int | None = 900,
    vmin: float | None = -2,
    vmax: float | None = 2,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
) -> dict[str, Axes] | plt.Figure | None:
    """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.

    Cells labelled `target_gene` or `control` in `labels` are ranked for differential expression
    against each other. The data are then scaled and optionally subsampled. Finally, the top-ranked
    genes of the `control` group are shown, with cells grouped by `.obs['mixscape_class']`.

    Args:
        adata: The annotated data object.
        labels: The column of `.obs` with target gene labels.
        target_gene: Target gene name to visualize heatmap for.
        control: Control category from the `labels` column.
        layer: Key from `adata.layers` whose value will be used to perform tests on.
        method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
        subsample_number: Subsample to this number of observations. Skipped if `None` or if the subset
            already has at most this many cells.
        vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
        vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
            Also used as `max_value` when scaling the data.
        return_fig: If `True`, return the :class:`~matplotlib.figure.Figure`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`. `n_genes` defaults to 20.

    Returns:
        The figure if `return_fig` is `True`. Otherwise, if `show` is `False`, the axes returned by
        `scanpy.pl.rank_genes_groups_heatmap`; else `None`.

    Raises:
        ValueError: If `.obs['mixscape_class']` is missing.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_heatmap(
        ...     adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
        ... )

    Preview:
        .. image:: /_static/docstring_previews/mixscape_heatmap.png
    """
    if "mixscape_class" not in adata.obs:
        raise ValueError("Please run `pt.tl.mixscape` first.")
    adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
    sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
    sc.pp.scale(adata_subset, max_value=vmax)
    # Subsampling is without replacement, so asking for more cells than exist would fail.
    if subsample_number is not None and subsample_number < adata_subset.n_obs:
        sc.pp.subsample(adata_subset, n_obs=subsample_number)

    n_genes = kwds.pop("n_genes", 20)
    # `return_fig` is not a parameter of the scanpy heatmap: forwarding it would reach `imshow` and fail.
    # Instead, keep the figure open (show=False) and recover it from the returned axes.
    plot_axes = sc.pl.rank_genes_groups_heatmap(
        adata_subset,
        groupby="mixscape_class",
        vmin=vmin,
        vmax=vmax,
        n_genes=n_genes,
        groups=[control],
        show=False if return_fig else show,
        save=save,
        **kwds,
    )
    if return_fig and plot_axes:
        return next(iter(plot_axes.values())).figure
    return plot_axes
672
+
673
def plot_perturbscore(  # pragma: no cover
    self,
    adata: AnnData,
    labels: str,
    target_gene: str,
    mixscape_class: str = "mixscape_class",
    color: str = "orange",
    palette: dict[str, str] | None = None,
    split_by: str | None = None,
    before_mixscape: bool = False,
    perturbation_type: str = "KO",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Figure | Axes | None:
    """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.

    Requires `pt.tl.mixscape` to be run first.

    https://satijalab.org/seurat/reference/plotperturbscore

    Args:
        adata: The annotated data object.
        labels: The column of `.obs` with target gene labels.
        target_gene: Target gene name to visualize perturbation scores for.
        mixscape_class: The column of `.obs` with mixscape classifications.
        color: Color of the target gene class (`before_mixscape=True`) or of the perturbed class.
            Control and non-perturbed cells are drawn in shades of grey.
        palette: Optional full color palette (class name -> color) that overrides all default colors.
        split_by: Optional column of `.obs` (e.g. biological replicates) used to draw one panel per category.
        before_mixscape: If `True`, split densities by the original target gene labels in `labels`.
            If `False` (default), split them by the mixscape classification in `mixscape_class`.
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
            Defaults to `KO`.
        return_fig: If `True`, return the :class:`~matplotlib.figure.Figure`.
        ax: Axes to draw on when `split_by` is `None`; a new figure is created if not given.
            Ignored when `split_by` is set, because the facet grid creates its own figure.
        show: Whether to call `plt.show()`.
        save: Path passed to `savefig` to save the figure.

    Returns:
        The figure if `return_fig` is `True`. Otherwise the axes if neither `show` nor `save` is set;
        else `None`.

    Raises:
        ValueError: If `mixscape` has not been run or no scores are stored for `target_gene`.

    Examples:
        Visualizing the perturbation scores for the cells in a dataset:

        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_perturbscore.png
    """
    if "mixscape" not in adata.uns:
        raise ValueError("Please run the `mixscape` function first.")
    scores_by_split = adata.uns["mixscape"][target_gene]
    # Copy each per-split frame so tagging it with its split name does not mutate `adata.uns`.
    frames = []
    for key in scores_by_split.keys():
        frame = scores_by_split[key].copy()
        frame["name"] = key
        frames.append(frame)
    if not frames:
        raise ValueError(f"No perturbation scores stored for target gene '{target_gene}'.")
    perturbation_score = pd.concat(frames)
    perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
    gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]

    def _density_peak() -> float:
        """Return the highest KDE value over the label groups, measured on a throwaway figure."""
        # A private figure keeps the caller's current figure untouched.
        tmp_fig, tmp_ax = plt.subplots()
        sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False, ax=tmp_ax)
        peak = max(line.get_data()[1].max() for line in tmp_ax.get_lines())
        plt.close(tmp_fig)
        return peak

    top_r = _density_peak()
    rng = np.random.default_rng()
    perturbation_score["y_jitter"] = perturbation_score["pvec"]

    def _jitter(column: str, value: str, low: float, high: float) -> None:
        """Draw uniform y positions in [low, high) for the cells where `column == value`."""
        mask = perturbation_score[column] == value
        perturbation_score.loc[mask, "y_jitter"] = rng.uniform(low=low, high=high, size=int(mask.sum()))

    if before_mixscape:
        # Split densities by the original target gene classification.
        hue_key = labels
        if palette is None:
            palette = {gd: "#7d7d7d", target_gene: color}
        _jitter(labels, gd, 0.001, top_r / 10)
        _jitter(labels, target_gene, -top_r / 10, 0)
        kde_kws: dict = {}
        plot_title = "Density Plot"
        single_legend_title = labels
        facet_legend_title = labels
    else:
        # Split densities by the mixscape classification.
        hue_key = "mix"
        perturbed_class = f"{target_gene} {perturbation_type}"
        non_perturbed_class = f"{target_gene} NP"
        if palette is None:
            palette = {gd: "#7d7d7d", non_perturbed_class: "#c9c9c9", perturbed_class: color}
        gd2 = list(set(perturbation_score["mix"]).difference([non_perturbed_class, perturbed_class]))[0]
        _jitter("mix", gd2, 0.001, top_r / 10)
        _jitter("mix", perturbed_class, -top_r / 10, 0)
        _jitter("mix", non_perturbed_class, -top_r / 10, 0)
        kde_kws = {"alpha": 0.7}
        plot_title = "Density"
        single_legend_title = "mixscape class"
        facet_legend_title = "mix"

    sns.set_theme(style="whitegrid")
    if split_by is not None:
        # One panel per `split_by` category; colors follow the same hue as the single-panel plot.
        g = sns.FacetGrid(data=perturbation_score, col=split_by, hue=hue_key, palette=palette, height=5, sharey=False)
        g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, **kde_kws)
        g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
        g.set_axis_labels("Perturbation score", "Cell density")
        g.add_legend(title=facet_legend_title, fontsize=14, title_fontsize=16)
        g.despine(left=True)
        fig = plt.gcf()
        result_ax = plt.gca()
    else:
        if ax is None:
            _, ax = plt.subplots()
        sns.kdeplot(
            data=perturbation_score,
            x="pvec",
            hue=hue_key,
            fill=True,
            common_norm=False,
            palette=palette,
            ax=ax,
            **kde_kws,
        )
        sns.scatterplot(
            data=perturbation_score, x="pvec", y="y_jitter", hue=hue_key, palette=palette, s=10, alpha=0.5, ax=ax
        )
        ax.set_xlabel("Perturbation score", fontsize=16)
        ax.set_ylabel("Cell density", fontsize=16)
        ax.set_title(plot_title, fontsize=18)
        ax.legend(title=single_legend_title, title_fontsize=14, fontsize=12)
        sns.despine(ax=ax)
        fig = ax.figure
        result_ax = ax

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return result_ax
    return None
849
+
850
+ def plot_violin( # pragma: no cover
851
+ self,
852
+ adata: AnnData,
853
+ target_gene_idents: str | list[str],
854
+ keys: str | Sequence[str] = "mixscape_class_p_ko",
855
+ groupby: str | None = "mixscape_class",
856
+ log: bool = False,
857
+ use_raw: bool | None = None,
858
+ stripplot: bool = True,
859
+ hue: str | None = None,
860
+ jitter: float | bool = True,
861
+ size: int = 1,
862
+ layer: str | None = None,
863
+ scale: Literal["area", "count", "width"] = "width",
864
+ order: Sequence[str] | None = None,
865
+ multi_panel: bool | None = None,
866
+ xlabel: str = "",
867
+ ylabel: str | Sequence[str] | None = None,
868
+ rotation: float | None = None,
869
+ ax: Axes | None = None,
870
+ show: bool | None = None,
871
+ save: bool | str | None = None,
872
+ **kwargs,
873
+ ):
874
+ """Violin plot using mixscape results.
875
+
876
+ Requires `pt.tl.mixscape` to be run first.
877
+
878
+ Args:
879
+ adata: The annotated data object.
880
+ target_gene_idents: Target gene name to plot.
881
+ keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
882
+ groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
883
+ log: Plot on logarithmic axis.
884
+ use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
885
+ stripplot: Add a stripplot on top of the violin plot.
886
+ order: Order in which to show the categories.
887
+ xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
888
+ ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
889
+ If `None` and `groubpy` is not `None`, defaults to `keys`.
890
+ show: Show the plot, do not return axis.
891
+ save: If `True` or a `str`, save the figure. A string is appended to the default filename.
892
+ Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
893
+ ax: A matplotlib axes object. Only works if plotting a single component.
894
+ **kwargs: Additional arguments to `seaborn.violinplot`.
895
+
896
+ Returns:
897
+ A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
898
+
899
+ Examples:
900
+ >>> import pertpy as pt
901
+ >>> mdata = pt.dt.papalexi_2021()
902
+ >>> ms_pt = pt.tl.Mixscape()
903
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
904
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
905
+ >>> ms_pt.plot_violin(
906
+ ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
907
+ ... )
908
+
909
+ Preview:
910
+ .. image:: /_static/docstring_previews/mixscape_violin.png
911
+ """
912
+ if isinstance(target_gene_idents, str):
913
+ mixscape_class_mask = adata.obs[groupby] == target_gene_idents
914
+ elif isinstance(target_gene_idents, list):
915
+ mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
916
+ for ident in target_gene_idents:
917
+ mixscape_class_mask |= adata.obs[groupby] == ident
918
+ adata = adata[mixscape_class_mask]
919
+
920
+ sanitize_anndata(adata)
921
+ use_raw = _check_use_raw(adata, use_raw)
922
+ if isinstance(keys, str):
923
+ keys = [keys]
924
+ keys = list(OrderedDict.fromkeys(keys)) # remove duplicates, preserving the order
925
+
926
+ if isinstance(ylabel, str | type(None)):
927
+ ylabel = [ylabel] * (1 if groupby is None else len(keys))
928
+ if groupby is None:
929
+ if len(ylabel) != 1:
930
+ raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
931
+ elif len(ylabel) != len(keys):
932
+ raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
933
+
934
+ if groupby is not None:
935
+ if hue is not None:
936
+ obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
937
+ else:
938
+ obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
939
+
940
+ else:
941
+ obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
942
+ if groupby is None:
943
+ obs_tidy = pd.melt(obs_df, value_vars=keys)
944
+ x = "variable"
945
+ ys = ["value"]
946
+ else:
947
+ obs_tidy = obs_df
948
+ x = groupby
949
+ ys = keys
950
+
951
+ if multi_panel and groupby is None and len(ys) == 1:
952
+ # This is a quick and dirty way for adapting scales across several
953
+ # keys if groupby is None.
954
+ y = ys[0]
955
+
956
+ g = sns.catplot(
957
+ y=y,
958
+ data=obs_tidy,
959
+ kind="violin",
960
+ scale=scale,
961
+ col=x,
962
+ col_order=keys,
963
+ sharey=False,
964
+ order=keys,
965
+ cut=0,
966
+ inner=None,
967
+ **kwargs,
968
+ )
969
+
970
+ if stripplot:
971
+ grouped_df = obs_tidy.groupby(x)
972
+ for ax_id, key in zip(range(g.axes.shape[1]), keys, strict=False):
973
+ sns.stripplot(
974
+ y=y,
975
+ data=grouped_df.get_group(key),
976
+ jitter=jitter,
977
+ size=size,
978
+ color="black",
979
+ ax=g.axes[0, ax_id],
980
+ )
981
+ if log:
982
+ g.set(yscale="log")
983
+ g.set_titles(col_template="{col_name}").set_xlabels("")
984
+ if rotation is not None:
985
+ for ax in g.axes[0]:
986
+ ax.tick_params(axis="x", labelrotation=rotation)
987
+ else:
988
+ # set by default the violin plot cut=0 to limit the extend
989
+ # of the violin plot (see stacked_violin code) for more info.
990
+ kwargs.setdefault("cut", 0)
991
+ kwargs.setdefault("inner")
992
+
993
+ if ax is None:
994
+ axs, _, _, _ = _utils.setup_axes(
995
+ ax=ax,
996
+ panels=["x"] if groupby is None else keys,
997
+ show_ticks=True,
998
+ right_margin=0.3,
999
+ )
1000
+ else:
1001
+ axs = [ax]
1002
+ for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
1003
+ ax = sns.violinplot(
1004
+ x=x,
1005
+ y=y,
1006
+ data=obs_tidy,
1007
+ order=order,
1008
+ orient="vertical",
1009
+ scale=scale,
1010
+ ax=ax,
1011
+ hue=hue,
1012
+ **kwargs,
1013
+ )
1014
+ # Get the handles and labels.
1015
+ handles, labels = ax.get_legend_handles_labels()
1016
+ if stripplot:
1017
+ ax = sns.stripplot(
1018
+ x=x,
1019
+ y=y,
1020
+ data=obs_tidy,
1021
+ order=order,
1022
+ jitter=jitter,
1023
+ color="black",
1024
+ size=size,
1025
+ ax=ax,
1026
+ hue=hue,
1027
+ dodge=True,
1028
+ )
1029
+ if xlabel == "" and groupby is not None and rotation is None:
1030
+ xlabel = groupby.replace("_", " ")
1031
+ ax.set_xlabel(xlabel)
1032
+ if ylab is not None:
1033
+ ax.set_ylabel(ylab)
1034
+
1035
+ if log:
1036
+ ax.set_yscale("log")
1037
+ if rotation is not None:
1038
+ ax.tick_params(axis="x", labelrotation=rotation)
1039
+
1040
+ show = settings.autoshow if show is None else show
1041
+ if hue is not None and stripplot is True:
1042
+ plt.legend(handles, labels)
1043
+ _utils.savefig_or_show("mixscape_violin", show=show, save=save)
1044
+
1045
+ if not show:
1046
+ if multi_panel and groupby is None and len(ys) == 1:
1047
+ return g
1048
+ elif len(axs) == 1:
1049
+ return axs[0]
1050
+ else:
1051
+ return axs
1052
+
1053
def plot_lda(  # pragma: no cover
    self,
    adata: AnnData,
    control: str,
    mixscape_class: str = "mixscape_class",
    mixscape_class_global: str = "mixscape_class_global",
    perturbation_type: str | None = "KO",
    lda_key: str | None = "mixscape_lda",
    n_components: int | None = None,
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
) -> plt.Figure | Axes | list[Axes] | None:
    """Visualizing perturbation responses with Linear Discriminant Analysis.

    Requires both `mixscape` and `lda` to be run first: the mixscape classification is read from
    `.obs[mixscape_class]` and the LDA projection from `.uns[lda_key]`.

    Args:
        adata: The annotated data object.
        control: Control category from the `pert_key` column.
        mixscape_class: The column of `.obs` with the mixscape classification result.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
            Defaults to 'KO'.
        lda_key: Key in `.uns` holding the LDA projection produced by `lda`. Defaults to 'mixscape_lda'.
        n_components: The number of dimensions of the UMAP embedding.
            Defaults to the number of columns of the LDA projection.
        color_map: Colormap forwarded to `scanpy.pl.umap`.
        palette: Palette forwarded to `scanpy.pl.umap`.
        return_fig: Forwarded to `scanpy.pl.umap`; when `True` the figure it creates is returned.
        ax: A matplotlib axes object to draw into, forwarded to `scanpy.pl.umap`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.umap`.

    Returns:
        The value returned by `scanpy.pl.umap`: the figure when `return_fig` is `True`, the axes when
        `show` is `False`, otherwise `None`.

    Raises:
        ValueError: If `mixscape_class` is not a column of `.obs` or `lda_key` is not a key of `.uns`.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_lda.png
    """
    # `!r` already quotes the key, so it must not be wrapped in extra quotes inside the message.
    if mixscape_class not in adata.obs:
        raise ValueError(f"Did not find `.obs[{mixscape_class!r}]`. Please run the `mixscape` function first.")
    if lda_key not in adata.uns:
        raise ValueError(f"Did not find `.uns[{lda_key!r}]`. Please run the `lda` function first.")

    # Keep only cells labelled as the perturbation type or as control.
    # NOTE(review): presumably `lda` projected exactly these cells; the `.obsm` assignment below
    # requires the projection to have one row per cell of the subset.
    adata_subset = adata[
        (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
    ].copy()
    adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
    if n_components is None:
        n_components = adata_subset.uns[lda_key].shape[1]

    # Build the neighbour graph on the LDA space, then embed it with UMAP for plotting.
    sc.pp.neighbors(adata_subset, use_rep=lda_key)
    sc.tl.umap(adata_subset, n_components=n_components)

    # Propagate scanpy's result so `return_fig=True` / `show=False` actually reach the caller.
    return sc.pl.umap(
        adata_subset,
        color=mixscape_class,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        show=show,
        save=save,
        ax=ax,
        **kwds,
    )