pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py CHANGED
@@ -1,12 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
- import warnings
4
- from typing import TYPE_CHECKING
3
+ import copy
4
+ from collections import OrderedDict
5
+ from typing import TYPE_CHECKING, Literal
5
6
 
7
+ import matplotlib.pyplot as plt
6
8
  import numpy as np
7
9
  import pandas as pd
8
10
  import scanpy as sc
9
- from rich import print
11
+ import seaborn as sns
12
+ from scanpy import get
13
+ from scanpy._settings import settings
14
+ from scanpy._utils import _check_use_raw, sanitize_anndata
15
+ from scanpy.plotting import _utils
10
16
  from scanpy.tools._utils import _choose_representation
11
17
  from scipy.sparse import csr_matrix, issparse, spmatrix
12
18
  from sklearn.mixture import GaussianMixture
@@ -14,11 +20,13 @@ from sklearn.mixture import GaussianMixture
14
20
  import pertpy as pt
15
21
 
16
22
  if TYPE_CHECKING:
23
+ from collections.abc import Sequence
24
+
17
25
  from anndata import AnnData
26
+ from matplotlib.axes import Axes
27
+ from matplotlib.colors import Colormap
18
28
  from scipy import sparse
19
29
 
20
- warnings.simplefilter("ignore")
21
-
22
30
 
23
31
  class Mixscape:
24
32
  """Python implementation of Mixscape."""
@@ -65,15 +73,15 @@ class Mixscape:
65
73
 
66
74
  Returns:
67
75
  If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
68
- Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
76
+ Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
69
77
 
70
78
  Examples:
71
79
  Calculate perturbation signature for each cell in the dataset:
72
80
 
73
81
  >>> import pertpy as pt
74
82
  >>> mdata = pt.dt.papalexi_2021()
75
- >>> mixscape_identifier = pt.tl.Mixscape()
76
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
83
+ >>> ms_pt = pt.tl.Mixscape()
84
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
77
85
  """
78
86
  if copy:
79
87
  adata = adata.copy()
@@ -86,18 +94,17 @@ class Mixscape:
86
94
  split_masks = [np.full(adata.n_obs, True, dtype=bool)]
87
95
  else:
88
96
  split_obs = adata.obs[split_by]
89
- cats = split_obs.unique()
90
- split_masks = [split_obs == cat for cat in cats]
97
+ split_masks = [split_obs == cat for cat in split_obs.unique()]
91
98
 
92
- R = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
99
+ representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
93
100
 
94
101
  for split_mask in split_masks:
95
102
  control_mask_split = control_mask & split_mask
96
103
 
97
- R_split = R[split_mask]
98
- R_control = R[control_mask_split]
104
+ R_split = representation[split_mask]
105
+ R_control = representation[control_mask_split]
99
106
 
100
- from pynndescent import NNDescent # saves a lot of import time
107
+ from pynndescent import NNDescent
101
108
 
102
109
  eps = kwargs.pop("epsilon", 0.1)
103
110
  nn_index = NNDescent(R_control, **kwargs)
@@ -161,7 +168,6 @@ class Mixscape:
161
168
 
162
169
  Args:
163
170
  adata: The annotated data object.
164
- pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
165
171
  labels: The column of `.obs` with target gene labels.
166
172
  control: Control category from the `pert_key` column.
167
173
  new_class_name: Name of mixscape classification to be stored in `.obs`.
@@ -172,31 +178,31 @@ class Mixscape:
172
178
  split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
173
179
  the perturbation signature for every replicate separately.
174
180
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
175
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
181
+ perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
176
182
  copy: Determines whether a copy of the `adata` is returned.
177
183
 
178
184
  Returns:
179
185
  If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
180
- Otherwise writes the results directly to `.obs` of the provided `adata`.
186
+ Otherwise, writes the results directly to `.obs` of the provided `adata`.
181
187
 
182
- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
183
- Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
188
+ - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
189
+ Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
184
190
 
185
- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
186
- Global classification result (perturbed, NP or NT)
191
+ - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
192
+ Global classification result (perturbed, NP or NT).
187
193
 
188
- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
189
- Posterior probabilities used to determine if a cell is KO (default).
190
- Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP
194
+ - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
195
+ Posterior probabilities used to determine if a cell is KO (default).
196
+ Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP.
191
197
 
192
198
  Examples:
193
199
  Calculate perturbation signature for each cell in the dataset:
194
200
 
195
201
  >>> import pertpy as pt
196
202
  >>> mdata = pt.dt.papalexi_2021()
197
- >>> mixscape_identifier = pt.tl.Mixscape()
198
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
199
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
203
+ >>> ms_pt = pt.tl.Mixscape()
204
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
205
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
200
206
  """
201
207
  if copy:
202
208
  adata = adata.copy()
@@ -220,10 +226,9 @@ class Mixscape:
220
226
  try:
221
227
  X = adata_comp.layers["X_pert"]
222
228
  except KeyError:
223
- print(
224
- '[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!'
225
- )
226
- raise
229
+ raise KeyError(
230
+ "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
231
+ ) from None
227
232
  # initialize return variables
228
233
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
229
234
  adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -305,12 +310,14 @@ class Mixscape:
305
310
  old_classes = adata.obs[new_class_name][all_cells]
306
311
  n_iter += 1
307
312
 
308
- adata.obs.loc[
309
- (adata.obs[new_class_name] == gene) & split_mask, new_class_name
310
- ] = f"{gene} {perturbation_type}"
313
+ adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
314
+ f"{gene} {perturbation_type}"
315
+ )
311
316
 
312
317
  adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
313
- adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
318
+ adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
319
+ post_prob
320
+ ).astype("int64")
314
321
  adata.uns["mixscape"] = gv_list
315
322
 
316
323
  if copy:
@@ -339,18 +346,18 @@ class Mixscape:
339
346
  control: Control category from the `pert_key` column.
340
347
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
341
348
  layer: Key from `adata.layers` whose value will be used to perform tests on.
342
- control: Control category from the `pert_key` column. Defaults to 'NT'.
343
- n_comps: Number of principal components to use. Defaults to 10.
349
+ control: Control category from the `pert_key` column.
350
+ n_comps: Number of principal components to use.
344
351
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
345
- logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25.
352
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
346
353
  split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
347
354
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
348
- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
355
+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
349
356
  copy: Determines whether a copy of the `adata` is returned.
350
357
 
351
358
  Returns:
352
359
  If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
353
- Otherwise writes the results directly to `.uns` of the provided `adata`.
360
+ Otherwise, writes the results directly to `.uns` of the provided `adata`.
354
361
 
355
362
  mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
356
363
  LDA result.
@@ -360,10 +367,10 @@ class Mixscape:
360
367
 
361
368
  >>> import pertpy as pt
362
369
  >>> mdata = pt.dt.papalexi_2021()
363
- >>> mixscape_identifier = pt.tl.Mixscape()
364
- >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
365
- >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
366
- >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
370
+ >>> ms_pt = pt.tl.Mixscape()
371
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
372
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
373
+ >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
367
374
  """
368
375
  if copy:
369
376
  adata = adata.copy()
@@ -437,7 +444,7 @@ class Mixscape:
437
444
  min_de_genes: float,
438
445
  logfc_threshold: float,
439
446
  ) -> dict[tuple, np.ndarray]:
440
- """determine gene sets across all splits/groups through differential gene expression
447
+ """Determine gene sets across all splits/groups through differential gene expression
441
448
 
442
449
  Args:
443
450
  adata: :class:`~anndata.AnnData` object
@@ -454,7 +461,13 @@ class Mixscape:
454
461
  adata_split = adata[split_mask].copy()
455
462
  # find top DE genes between cells with targeting and non-targeting gRNAs
456
463
  sc.tl.rank_genes_groups(
457
- adata_split, layer=layer, groupby=labels, groups=genes, reference=control, method="t-test"
464
+ adata_split,
465
+ layer=layer,
466
+ groupby=labels,
467
+ groups=genes,
468
+ reference=control,
469
+ method="t-test",
470
+ use_raw=False,
458
471
  )
459
472
  # get DE genes for each gene
460
473
  for gene in genes:
@@ -469,15 +482,6 @@ class Mixscape:
469
482
  return perturbation_markers
470
483
 
471
484
  def _get_column_indices(self, adata, col_names):
472
- """Fetches the column indices in X for a given list of column names
473
-
474
- Args:
475
- adata: :class:`~anndata.AnnData` object
476
- col_names: Column names to extract the indices for
477
-
478
- Returns:
479
- Set of column indices
480
- """
481
485
  if isinstance(col_names, str): # pragma: no cover
482
486
  col_names = [col_names]
483
487
 
@@ -501,3 +505,621 @@ class Mixscape:
501
505
  sd = X.std()
502
506
 
503
507
  return [mu, sd]
508
+
509
+ def plot_barplot( # pragma: no cover
510
+ self,
511
+ adata: AnnData,
512
+ guide_rna_column: str,
513
+ mixscape_class_global: str = "mixscape_class_global",
514
+ axis_text_x_size: int = 8,
515
+ axis_text_y_size: int = 6,
516
+ axis_title_size: int = 8,
517
+ legend_title_size: int = 8,
518
+ legend_text_size: int = 8,
519
+ return_fig: bool | None = None,
520
+ ax: Axes | None = None,
521
+ show: bool | None = None,
522
+ save: bool | str | None = None,
523
+ ):
524
+ """Barplot to visualize perturbation scores calculated by the `mixscape` function.
525
+
526
+ Args:
527
+ adata: The annotated data object.
528
+ guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
529
+ The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
530
+ mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
531
+ show: Show the plot, do not return axis.
532
+ save: If True or a str, save the figure. A string is appended to the default filename.
533
+ Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
534
+
535
+ Returns:
536
+ If `show==False`, return a :class:`~matplotlib.axes.Axes`.
537
+
538
+ Examples:
539
+ >>> import pertpy as pt
540
+ >>> mdata = pt.dt.papalexi_2021()
541
+ >>> ms_pt = pt.tl.Mixscape()
542
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
543
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
544
+ >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
545
+
546
+ Preview:
547
+ .. image:: /_static/docstring_previews/mixscape_barplot.png
548
+ """
549
+ if mixscape_class_global not in adata.obs:
550
+ raise ValueError("Please run the `mixscape` function first.")
551
+ count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
552
+ all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
553
+ KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
554
+ KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)
555
+
556
+ new_levels = KO_cells_percentage[guide_rna_column]
557
+ all_cells_percentage[guide_rna_column] = pd.Categorical(
558
+ all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
559
+ )
560
+ all_cells_percentage[mixscape_class_global] = pd.Categorical(
561
+ all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
562
+ )
563
+ all_cells_percentage["gene"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[0]
564
+ all_cells_percentage["guide_number"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[1]
565
+ all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
566
+ NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
567
+
568
+ if show:
569
+ color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
570
+ unique_genes = NP_KO_cells["gene"].unique()
571
+ fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True)
572
+ for i, gene in enumerate(unique_genes):
573
+ ax = axs[int(i / 5), i % 5]
574
+ grouped_df = (
575
+ NP_KO_cells[NP_KO_cells["gene"] == gene]
576
+ .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
577
+ .sum()
578
+ .unstack()
579
+ )
580
+ grouped_df.plot(
581
+ kind="bar",
582
+ stacked=True,
583
+ color=[color_mapping[col] for col in grouped_df.columns],
584
+ ax=ax,
585
+ width=0.8,
586
+ legend=False,
587
+ )
588
+ ax.set_title(
589
+ gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
590
+ )
591
+ ax.set(xlabel="sgRNA", ylabel="% of cells")
592
+ sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
593
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
594
+ ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
595
+ fig.subplots_adjust(right=0.8)
596
+ fig.subplots_adjust(hspace=0.5, wspace=0.5)
597
+ ax.legend(
598
+ title="mixscape_class_global",
599
+ loc="center right",
600
+ bbox_to_anchor=(2.2, 3.5),
601
+ frameon=True,
602
+ fontsize=legend_text_size,
603
+ title_fontsize=legend_title_size,
604
+ )
605
+
606
+ plt.tight_layout()
607
+ _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
608
+
609
+ def plot_heatmap( # pragma: no cover
610
+ self,
611
+ adata: AnnData,
612
+ labels: str,
613
+ target_gene: str,
614
+ control: str,
615
+ layer: str | None = None,
616
+ method: str | None = "wilcoxon",
617
+ subsample_number: int | None = 900,
618
+ vmin: float | None = -2,
619
+ vmax: float | None = 2,
620
+ return_fig: bool | None = None,
621
+ show: bool | None = None,
622
+ save: bool | str | None = None,
623
+ **kwds,
624
+ ) -> Axes | None:
625
+ """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
626
+
627
+ Args:
628
+ adata: The annotated data object.
629
+ labels: The column of `.obs` with target gene labels.
630
+ target_gene: Target gene name to visualize heatmap for.
631
+ control: Control category from the `pert_key` column.
632
+ layer: Key from `adata.layers` whose value will be used to perform tests on.
633
+ method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
634
+ subsample_number: Subsample to this number of observations.
635
+ vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
636
+ vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
637
+ show: Show the plot, do not return axis.
638
+ save: If `True` or a `str`, save the figure. A string is appended to the default filename.
639
+ Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
640
+ ax: A matplotlib axes object. Only works if plotting a single component.
641
+ **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
642
+
643
+ Returns:
644
+ If `show==False`, return a :class:`~matplotlib.axes.Axes`.
645
+
646
+ Examples:
647
+ >>> import pertpy as pt
648
+ >>> mdata = pt.dt.papalexi_2021()
649
+ >>> ms_pt = pt.tl.Mixscape()
650
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
651
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
652
+ >>> ms_pt.plot_heatmap(
653
+ ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
654
+ ... )
655
+
656
+ Preview:
657
+ .. image:: /_static/docstring_previews/mixscape_heatmap.png
658
+ """
659
+ if "mixscape_class" not in adata.obs:
660
+ raise ValueError("Please run `pt.tl.mixscape` first.")
661
+ adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
662
+ sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
663
+ sc.pp.scale(adata_subset, max_value=vmax)
664
+ sc.pp.subsample(adata_subset, n_obs=subsample_number)
665
+
666
+ return sc.pl.rank_genes_groups_heatmap(
667
+ adata_subset,
668
+ groupby="mixscape_class",
669
+ vmin=vmin,
670
+ vmax=vmax,
671
+ n_genes=20,
672
+ groups=["NT"],
673
+ return_fig=return_fig,
674
+ show=show,
675
+ save=save,
676
+ **kwds,
677
+ )
678
+
679
+ def plot_perturbscore( # pragma: no cover
680
+ self,
681
+ adata: AnnData,
682
+ labels: str,
683
+ target_gene: str,
684
+ mixscape_class: str = "mixscape_class",
685
+ color: str = "orange",
686
+ palette: dict[str, str] = None,
687
+ split_by: str = None,
688
+ before_mixscape: bool = False,
689
+ perturbation_type: str = "KO",
690
+ return_fig: bool | None = None,
691
+ ax: Axes | None = None,
692
+ show: bool | None = None,
693
+ save: bool | str | None = None,
694
+ ) -> None:
695
+ """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
696
+
697
+ Requires `pt.tl.mixscape` to be run first.
698
+
699
+ https://satijalab.org/seurat/reference/plotperturbscore
700
+
701
+ Args:
702
+ adata: The annotated data object.
703
+ labels: The column of `.obs` with target gene labels.
704
+ target_gene: Target gene name to visualize perturbation scores for.
705
+ mixscape_class: The column of `.obs` with mixscape classifications.
706
+ color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
707
+ palette: Optional full color palette to overwrite all colors.
708
+ split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
709
+ the perturbation signature for every replicate separately.
710
+ before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
711
+ Default is `False`, which plots cells by their mixscape classification; set to `True` to plot by original class ID.
712
+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
713
+
714
+ Examples:
715
+ Visualizing the perturbation scores for the cells in a dataset:
716
+
717
+ >>> import pertpy as pt
718
+ >>> mdata = pt.dt.papalexi_2021()
719
+ >>> ms_pt = pt.tl.Mixscape()
720
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
721
+ >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
722
+ >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
723
+
724
+ Preview:
725
+ .. image:: /_static/docstring_previews/mixscape_perturbscore.png
726
+ """
727
+ if "mixscape" not in adata.uns:
728
+ raise ValueError("Please run the `mixscape` function first.")
729
+ perturbation_score = None
730
+ for key in adata.uns["mixscape"][target_gene].keys():
731
+ perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
732
+ perturbation_score_temp["name"] = key
733
+ if perturbation_score is None:
734
+ perturbation_score = copy.deepcopy(perturbation_score_temp)
735
+ else:
736
+ perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
737
+ perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
738
+ gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
739
+
740
+ # If before_mixscape is True, split densities based on original target gene classification
741
+ if before_mixscape is True:
742
+ palette = {gd: "#7d7d7d", target_gene: color}
743
+ plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
744
+ top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
745
+ plt.close()
746
+ perturbation_score["y_jitter"] = perturbation_score["pvec"]
747
+ rng = np.random.default_rng()
748
+ perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
749
+ low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
750
+ )
751
+ perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
752
+ low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
753
+ )
754
+ # If split_by is provided, split densities based on the split_by
755
+ if split_by is not None:
756
+ sns.set_theme(style="whitegrid")
757
+ g = sns.FacetGrid(
758
+ data=perturbation_score, col=split_by, hue=split_by, palette=palette, height=5, sharey=False
759
+ )
760
+ g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, palette=palette)
761
+ g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5, palette=palette)
762
+ g.set_axis_labels("Perturbation score", "Cell density")
763
+ g.add_legend(title=split_by, fontsize=14, title_fontsize=16)
764
+ g.despine(left=True)
765
+
766
+ # If split_by is not provided, create a single plot
767
+ else:
768
+ sns.set_theme(style="whitegrid")
769
+ sns.kdeplot(
770
+ data=perturbation_score, x="pvec", hue="gene_target", fill=True, common_norm=False, palette=palette
771
+ )
772
+ sns.scatterplot(
773
+ data=perturbation_score, x="pvec", y="y_jitter", hue="gene_target", palette=palette, s=10, alpha=0.5
774
+ )
775
+ plt.xlabel("Perturbation score", fontsize=16)
776
+ plt.ylabel("Cell density", fontsize=16)
777
+ plt.title("Density Plot", fontsize=18)
778
+ plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
779
+ sns.despine()
780
+
781
+ if save:
782
+ plt.savefig(save, bbox_inches="tight")
783
+ if show:
784
+ plt.show()
785
+ if return_fig:
786
+ return plt.gcf()
787
+ if not (show or save):
788
+ return plt.gca()
789
+
790
+ # If before_mixscape is False, split densities based on mixscape classifications
791
+ else:
792
+ if palette is None:
793
+ palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
794
+ plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
795
+ top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
796
+ plt.close()
797
+ perturbation_score["y_jitter"] = perturbation_score["pvec"]
798
+ rng = np.random.default_rng()
799
+ gd2 = list(
800
+ set(perturbation_score["mix"]).difference([f"{target_gene} NP", f"{target_gene} {perturbation_type}"])
801
+ )[0]
802
+ perturbation_score.loc[perturbation_score["mix"] == gd2, "y_jitter"] = rng.uniform(
803
+ low=0.001, high=top_r / 10, size=sum(perturbation_score["mix"] == gd2)
804
+ ).astype(np.float32)
805
+ perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} {perturbation_type}", "y_jitter"] = (
806
+ rng.uniform(
807
+ low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} {perturbation_type}")
808
+ )
809
+ )
810
+ perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} NP", "y_jitter"] = rng.uniform(
811
+ low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} NP")
812
+ )
813
+ # If split_by is provided, split densities based on the split_by
814
+ if split_by is not None:
815
+ sns.set_theme(style="whitegrid")
816
+ g = sns.FacetGrid(
817
+ data=perturbation_score, col=split_by, hue="mix", palette=palette, height=5, sharey=False
818
+ )
819
+ g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, alpha=0.7)
820
+ g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
821
+ g.set_axis_labels("Perturbation score", "Cell density")
822
+ g.add_legend(title="mix", fontsize=14, title_fontsize=16)
823
+ g.despine(left=True)
824
+
825
+ # If split_by is not provided, create a single plot
826
+ else:
827
+ sns.set_theme(style="whitegrid")
828
+ sns.kdeplot(
829
+ data=perturbation_score,
830
+ x="pvec",
831
+ hue="mix",
832
+ fill=True,
833
+ common_norm=False,
834
+ palette=palette,
835
+ alpha=0.7,
836
+ )
837
+ sns.scatterplot(
838
+ data=perturbation_score, x="pvec", y="y_jitter", hue="mix", palette=palette, s=10, alpha=0.5
839
+ )
840
+ plt.xlabel("Perturbation score", fontsize=16)
841
+ plt.ylabel("Cell density", fontsize=16)
842
+ plt.title("Density", fontsize=18)
843
+ plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
844
+ sns.despine()
845
+
846
+ if save:
847
+ plt.savefig(save, bbox_inches="tight")
848
+ if show:
849
+ plt.show()
850
+ if return_fig:
851
+ return plt.gcf()
852
+ if not (show or save):
853
+ return plt.gca()
854
+
855
def plot_violin(  # pragma: no cover
    self,
    adata: AnnData,
    target_gene_idents: str | Sequence[str],
    keys: str | Sequence[str] = "mixscape_class_p_ko",
    groupby: str | None = "mixscape_class",
    log: bool = False,
    use_raw: bool | None = None,
    stripplot: bool = True,
    hue: str | None = None,
    jitter: float | bool = True,
    size: int = 1,
    layer: str | None = None,
    scale: Literal["area", "count", "width"] = "width",
    order: Sequence[str] | None = None,
    multi_panel: bool | None = None,
    xlabel: str = "",
    ylabel: str | Sequence[str] | None = None,
    rotation: float | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
):
    """Violin plot using mixscape results.

    Requires `pt.tl.mixscape` to be run first.

    Args:
        adata: The annotated data object.
        target_gene_idents: A single label or a sequence of labels of `adata.obs[groupby]`.
            Only cells carrying one of these labels are plotted.
        keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
            Duplicates are dropped, keeping the first occurrence.
        groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
        log: Plot on logarithmic axis.
        use_raw: Whether to use `raw` attribute of `adata`.
        stripplot: Add a stripplot on top of the violin plot.
        hue: Additional `.obs` key forwarded as `hue` to the violin plot and the stripplot.
        jitter: Forwarded as `jitter` to `seaborn.stripplot`.
        size: Forwarded as `size` (marker size) to `seaborn.stripplot`.
        layer: Forwarded as `layer` when the plotted values are fetched from `adata`.
        scale: Forwarded as `scale` to the seaborn violin plot.
        order: Order in which to show the categories.
        multi_panel: If truthy and `groupby` is `None`, draw one panel per key with independent y-axes.
        xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
        ylabel: Label of the y-axis. A single string (or `None`) is repeated for every panel;
            a sequence must contain exactly one entry per panel.
        rotation: Rotation of the x-tick labels; `None` leaves them untouched.
        ax: A matplotlib axes object. Only works if plotting a single component.
        show: Show the plot, do not return axis. `None` falls back to `settings.autoshow`.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwargs: Additional arguments to `seaborn.violinplot`.

    Returns:
        `None` if the plot is shown. Otherwise the seaborn grid in the multi-panel case,
        a single axes object if exactly one panel was drawn, or the list of axes.

    Raises:
        ValueError: If the number of y-labels does not match the number of panels.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_violin(
        ...     adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
        ... )

    Preview:
        .. image:: /_static/docstring_previews/mixscape_violin.png
    """
    # Accept a single label as well as any sequence of labels (list, tuple, ...);
    # previously anything other than `str`/`list` left the mask undefined.
    # NOTE(review): this filter reads `adata.obs[groupby]`, so the `groupby=None`
    # paths below look unreachable unless such a column exists - confirm intent.
    if isinstance(target_gene_idents, str):
        selected_idents = [target_gene_idents]
    else:
        selected_idents = list(target_gene_idents)
    adata = adata[adata.obs[groupby].isin(selected_idents)]

    sanitize_anndata(adata)
    use_raw = _check_use_raw(adata, use_raw)
    if isinstance(keys, str):
        keys = [keys]
    keys = list(OrderedDict.fromkeys(keys))  # remove duplicates, preserving the order

    # One y-label per drawn panel: a single panel without grouping, one per key otherwise.
    n_expected_labels = 1 if groupby is None else len(keys)
    if ylabel is None or isinstance(ylabel, str):
        ylabel = [ylabel] * n_expected_labels
    if len(ylabel) != n_expected_labels:
        raise ValueError(f"Expected number of y-labels to be `{n_expected_labels}`, found `{len(ylabel)}`.")

    if groupby is None:
        # Without grouping, the keys themselves become the categorical x-axis.
        obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
        obs_tidy = pd.melt(obs_df, value_vars=keys)
        x = "variable"
        ys = ["value"]
    else:
        columns = [groupby] + keys + ([hue] if hue is not None else [])
        obs_tidy = get.obs_df(adata, keys=columns, layer=layer, use_raw=use_raw)
        x = groupby
        ys = keys

    use_panels = bool(multi_panel and groupby is None and len(ys) == 1)
    # Legend entries are only collected in the single-figure branch; keep them
    # defined so the legend call below cannot hit an unbound name.
    handles, labels = None, None

    if use_panels:
        # This is a quick and dirty way for adapting scales across several
        # keys if groupby is None.
        y = ys[0]

        g = sns.catplot(
            y=y,
            data=obs_tidy,
            kind="violin",
            scale=scale,
            col=x,
            col_order=keys,
            sharey=False,
            order=keys,
            cut=0,
            inner=None,
            **kwargs,
        )

        if stripplot:
            grouped_df = obs_tidy.groupby(x)
            for panel_ax, key in zip(g.axes[0], keys):
                sns.stripplot(
                    y=y,
                    data=grouped_df.get_group(key),
                    jitter=jitter,
                    size=size,
                    color="black",
                    ax=panel_ax,
                )
        if log:
            g.set(yscale="log")
        g.set_titles(col_template="{col_name}").set_xlabels("")
        if rotation is not None:
            for panel_ax in g.axes[0]:
                panel_ax.tick_params(axis="x", labelrotation=rotation)
    else:
        # set by default the violin plot cut=0 to limit the extent
        # of the violin plot (see stacked_violin code) for more info.
        kwargs.setdefault("cut", 0)
        kwargs.setdefault("inner")

        if ax is None:
            axs, _, _, _ = _utils.setup_axes(
                ax=ax,
                panels=["x"] if groupby is None else keys,
                show_ticks=True,
                right_margin=0.3,
            )
        else:
            axs = [ax]

        # The default x-label is derived from `groupby` unless tick rotation was requested.
        if xlabel == "" and groupby is not None and rotation is None:
            xlabel = groupby.replace("_", " ")

        for cur_ax, y, ylab in zip(axs, ys, ylabel):
            cur_ax = sns.violinplot(
                x=x,
                y=y,
                data=obs_tidy,
                order=order,
                orient="vertical",
                scale=scale,
                ax=cur_ax,
                hue=hue,
                **kwargs,
            )
            # Grab the legend entries before the stripplot adds its own.
            handles, labels = cur_ax.get_legend_handles_labels()
            if stripplot:
                cur_ax = sns.stripplot(
                    x=x,
                    y=y,
                    data=obs_tidy,
                    order=order,
                    jitter=jitter,
                    color="black",
                    size=size,
                    ax=cur_ax,
                    hue=hue,
                    dodge=True,
                )
            cur_ax.set_xlabel(xlabel)
            if ylab is not None:
                cur_ax.set_ylabel(ylab)

            if log:
                cur_ax.set_yscale("log")
            if rotation is not None:
                cur_ax.tick_params(axis="x", labelrotation=rotation)

    show = settings.autoshow if show is None else show
    if hue is not None and stripplot is True and handles is not None:
        plt.legend(handles, labels)
    _utils.savefig_or_show("mixscape_violin", show=show, save=save)

    if show:
        return None
    if use_panels:
        return g
    return axs[0] if len(axs) == 1 else axs
1057
+
1058
def plot_lda(  # pragma: no cover
    self,
    adata: AnnData,
    control: str,
    mixscape_class: str = "mixscape_class",
    mixscape_class_global: str = "mixscape_class_global",
    perturbation_type: str | None = "KO",
    lda_key: str | None = "mixscape_lda",
    n_components: int | None = None,
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwds,
):
    """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.

    Args:
        adata: The annotated data object.
        control: Value of the `mixscape_class_global` column that marks control cells.
        mixscape_class: The column of `.obs` with the mixscape classification result.
        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
            Only cells whose global class equals this value or `control` are plotted.
        lda_key: Key of `.uns` holding the LDA results. `None` falls back to `"mixscape_lda"`.
        n_components: The number of dimensions of the embedding.
            Defaults to the number of columns of `.uns[lda_key]`.
        color_map: Forwarded as `color_map` to `scanpy.pl.umap`.
        palette: Forwarded as `palette` to `scanpy.pl.umap`.
        return_fig: Forwarded to `scanpy.pl.umap`; if truthy, its return value is returned.
        ax: Forwarded as `ax` to `scanpy.pl.umap`.
        show: Show the plot, do not return axis.
        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
            Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
        **kwds: Additional arguments to `scanpy.pl.umap`.

    Returns:
        Whatever `scanpy.pl.umap` returns if `return_fig` is truthy, otherwise `None`.

    Raises:
        ValueError: If `mixscape_class` is missing from `.obs` or `lda_key` is missing from `.uns`.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
        >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")

    Preview:
        .. image:: /_static/docstring_previews/mixscape_lda.png
    """
    # The signature allows `None`; treat it as "use the default key" as documented.
    if lda_key is None:
        lda_key = "mixscape_lda"

    # The key is already wrapped in literal quotes, so no `!r` (it would double-quote the name).
    if mixscape_class not in adata.obs:
        raise ValueError(f'Did not find `.obs["{mixscape_class}"]`. Please run the `mixscape` function first.')
    if lda_key not in adata.uns:
        raise ValueError(f'Did not find `.uns["{lda_key}"]`. Please run the `lda` function first.')

    # Restrict to perturbed cells of the requested type plus the control cells.
    global_class = adata.obs[mixscape_class_global]
    keep_cells = (global_class == perturbation_type) | (global_class == control)
    adata_subset = adata[keep_cells].copy()

    # Expose the stored LDA result as an embedding so neighbors/UMAP can use it.
    lda_result = adata_subset.uns[lda_key]
    adata_subset.obsm[lda_key] = lda_result
    if n_components is None:
        n_components = lda_result.shape[1]

    sc.pp.neighbors(adata_subset, use_rep=lda_key)
    sc.tl.umap(adata_subset, n_components=n_components)
    umap_result = sc.pl.umap(
        adata_subset,
        color=mixscape_class,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        show=show,
        save=save,
        ax=ax,
        **kwds,
    )
    # Previously the result was discarded, so `return_fig=True` had no visible effect.
    if return_fig:
        return umap_result