pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
 - pertpy/data/__init__.py +66 -1
 - pertpy/data/_dataloader.py +28 -26
 - pertpy/data/_datasets.py +261 -92
 - pertpy/metadata/__init__.py +6 -0
 - pertpy/metadata/_cell_line.py +795 -0
 - pertpy/metadata/_compound.py +128 -0
 - pertpy/metadata/_drug.py +238 -0
 - pertpy/metadata/_look_up.py +569 -0
 - pertpy/metadata/_metadata.py +70 -0
 - pertpy/metadata/_moa.py +125 -0
 - pertpy/plot/__init__.py +0 -13
 - pertpy/preprocessing/__init__.py +2 -0
 - pertpy/preprocessing/_guide_rna.py +89 -6
 - pertpy/tools/__init__.py +48 -15
 - pertpy/tools/_augur.py +329 -32
 - pertpy/tools/_cinemaot.py +145 -6
 - pertpy/tools/_coda/_base_coda.py +1237 -116
 - pertpy/tools/_coda/_sccoda.py +66 -36
 - pertpy/tools/_coda/_tasccoda.py +46 -39
 - pertpy/tools/_dialogue.py +180 -77
 - pertpy/tools/_differential_gene_expression/__init__.py +20 -0
 - pertpy/tools/_differential_gene_expression/_base.py +657 -0
 - pertpy/tools/_differential_gene_expression/_checks.py +41 -0
 - pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
 - pertpy/tools/_differential_gene_expression/_edger.py +125 -0
 - pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
 - pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
 - pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
 - pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
 - pertpy/tools/_distances/_distance_tests.py +29 -24
 - pertpy/tools/_distances/_distances.py +584 -98
 - pertpy/tools/_enrichment.py +460 -0
 - pertpy/tools/_kernel_pca.py +1 -1
 - pertpy/tools/_milo.py +406 -49
 - pertpy/tools/_mixscape.py +677 -55
 - pertpy/tools/_perturbation_space/_clustering.py +10 -3
 - pertpy/tools/_perturbation_space/_comparison.py +112 -0
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
 - pertpy/tools/_perturbation_space/_simple.py +52 -11
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_base_components.py +2 -3
 - pertpy/tools/_scgen/_scgen.py +706 -0
 - pertpy/tools/_scgen/_utils.py +3 -5
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
 - pertpy-0.8.0.dist-info/RECORD +57 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_augur.py +0 -234
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_coda.py +0 -1001
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_guide_rna.py +0 -82
 - pertpy/plot/_milopy.py +0 -284
 - pertpy/plot/_mixscape.py +0 -594
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_differential_gene_expression.py +0 -99
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
 
    
        pertpy/tools/_mixscape.py
    CHANGED
    
    | 
         @@ -1,12 +1,18 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from __future__ import annotations
         
     | 
| 
       2 
2 
     | 
    
         | 
| 
       3 
     | 
    
         
            -
            import  
     | 
| 
       4 
     | 
    
         
            -
            from  
     | 
| 
      
 3 
     | 
    
         
            +
            import copy
         
     | 
| 
      
 4 
     | 
    
         
            +
            from collections import OrderedDict
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import TYPE_CHECKING, Literal
         
     | 
| 
       5 
6 
     | 
    
         | 
| 
      
 7 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
       6 
8 
     | 
    
         
             
            import numpy as np
         
     | 
| 
       7 
9 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
       8 
10 
     | 
    
         
             
            import scanpy as sc
         
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
      
 11 
     | 
    
         
            +
            import seaborn as sns
         
     | 
| 
      
 12 
     | 
    
         
            +
            from scanpy import get
         
     | 
| 
      
 13 
     | 
    
         
            +
            from scanpy._settings import settings
         
     | 
| 
      
 14 
     | 
    
         
            +
            from scanpy._utils import _check_use_raw, sanitize_anndata
         
     | 
| 
      
 15 
     | 
    
         
            +
            from scanpy.plotting import _utils
         
     | 
| 
       10 
16 
     | 
    
         
             
            from scanpy.tools._utils import _choose_representation
         
     | 
| 
       11 
17 
     | 
    
         
             
            from scipy.sparse import csr_matrix, issparse, spmatrix
         
     | 
| 
       12 
18 
     | 
    
         
             
            from sklearn.mixture import GaussianMixture
         
     | 
| 
         @@ -14,11 +20,13 @@ from sklearn.mixture import GaussianMixture 
     | 
|
| 
       14 
20 
     | 
    
         
             
            import pertpy as pt
         
     | 
| 
       15 
21 
     | 
    
         | 
| 
       16 
22 
     | 
    
         
             
            if TYPE_CHECKING:
         
     | 
| 
      
 23 
     | 
    
         
            +
                from collections.abc import Sequence
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
       17 
25 
     | 
    
         
             
                from anndata import AnnData
         
     | 
| 
      
 26 
     | 
    
         
            +
                from matplotlib.axes import Axes
         
     | 
| 
      
 27 
     | 
    
         
            +
                from matplotlib.colors import Colormap
         
     | 
| 
       18 
28 
     | 
    
         
             
                from scipy import sparse
         
     | 
| 
       19 
29 
     | 
    
         | 
| 
       20 
     | 
    
         
            -
            warnings.simplefilter("ignore")
         
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
30 
     | 
    
         | 
| 
       23 
31 
     | 
    
         
             
            class Mixscape:
         
     | 
| 
       24 
32 
     | 
    
         
             
                """Python implementation of Mixscape."""
         
     | 
| 
         @@ -65,15 +73,15 @@ class Mixscape: 
     | 
|
| 
       65 
73 
     | 
    
         | 
| 
       66 
74 
     | 
    
         
             
                    Returns:
         
     | 
| 
       67 
75 
     | 
    
         
             
                        If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
         
     | 
| 
       68 
     | 
    
         
            -
                        Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
         
     | 
| 
      
 76 
     | 
    
         
            +
                        Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
         
     | 
| 
       69 
77 
     | 
    
         | 
| 
       70 
78 
     | 
    
         
             
                    Examples:
         
     | 
| 
       71 
79 
     | 
    
         
             
                        Calculate perturbation signature for each cell in the dataset:
         
     | 
| 
       72 
80 
     | 
    
         | 
| 
       73 
81 
     | 
    
         
             
                        >>> import pertpy as pt
         
     | 
| 
       74 
82 
     | 
    
         
             
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       75 
     | 
    
         
            -
                        >>>  
     | 
| 
       76 
     | 
    
         
            -
                        >>>  
     | 
| 
      
 83 
     | 
    
         
            +
                        >>> ms_pt = pt.tl.Mixscape()
         
     | 
| 
      
 84 
     | 
    
         
            +
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
         
     | 
| 
       77 
85 
     | 
    
         
             
                    """
         
     | 
| 
       78 
86 
     | 
    
         
             
                    if copy:
         
     | 
| 
       79 
87 
     | 
    
         
             
                        adata = adata.copy()
         
     | 
| 
         @@ -86,18 +94,17 @@ class Mixscape: 
     | 
|
| 
       86 
94 
     | 
    
         
             
                        split_masks = [np.full(adata.n_obs, True, dtype=bool)]
         
     | 
| 
       87 
95 
     | 
    
         
             
                    else:
         
     | 
| 
       88 
96 
     | 
    
         
             
                        split_obs = adata.obs[split_by]
         
     | 
| 
       89 
     | 
    
         
            -
                         
     | 
| 
       90 
     | 
    
         
            -
                        split_masks = [split_obs == cat for cat in cats]
         
     | 
| 
      
 97 
     | 
    
         
            +
                        split_masks = [split_obs == cat for cat in split_obs.unique()]
         
     | 
| 
       91 
98 
     | 
    
         | 
| 
       92 
     | 
    
         
            -
                     
     | 
| 
      
 99 
     | 
    
         
            +
                    representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
         
     | 
| 
       93 
100 
     | 
    
         | 
| 
       94 
101 
     | 
    
         
             
                    for split_mask in split_masks:
         
     | 
| 
       95 
102 
     | 
    
         
             
                        control_mask_split = control_mask & split_mask
         
     | 
| 
       96 
103 
     | 
    
         | 
| 
       97 
     | 
    
         
            -
                        R_split =  
     | 
| 
       98 
     | 
    
         
            -
                        R_control =  
     | 
| 
      
 104 
     | 
    
         
            +
                        R_split = representation[split_mask]
         
     | 
| 
      
 105 
     | 
    
         
            +
                        R_control = representation[control_mask_split]
         
     | 
| 
       99 
106 
     | 
    
         | 
| 
       100 
     | 
    
         
            -
                        from pynndescent import NNDescent 
     | 
| 
      
 107 
     | 
    
         
            +
                        from pynndescent import NNDescent
         
     | 
| 
       101 
108 
     | 
    
         | 
| 
       102 
109 
     | 
    
         
             
                        eps = kwargs.pop("epsilon", 0.1)
         
     | 
| 
       103 
110 
     | 
    
         
             
                        nn_index = NNDescent(R_control, **kwargs)
         
     | 
| 
         @@ -161,7 +168,6 @@ class Mixscape: 
     | 
|
| 
       161 
168 
     | 
    
         | 
| 
       162 
169 
     | 
    
         
             
                    Args:
         
     | 
| 
       163 
170 
     | 
    
         
             
                        adata: The annotated data object.
         
     | 
| 
       164 
     | 
    
         
            -
                        pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
         
     | 
| 
       165 
171 
     | 
    
         
             
                        labels: The column of `.obs` with target gene labels.
         
     | 
| 
       166 
172 
     | 
    
         
             
                        control: Control category from the `pert_key` column.
         
     | 
| 
       167 
173 
     | 
    
         
             
                        new_class_name: Name of mixscape classification to be stored in `.obs`.
         
     | 
| 
         @@ -172,31 +178,31 @@ class Mixscape: 
     | 
|
| 
       172 
178 
     | 
    
         
             
                        split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
         
     | 
| 
       173 
179 
     | 
    
         
             
                                the perturbation signature for every replicate separately.
         
     | 
| 
       174 
180 
     | 
    
         
             
                        pval_cutoff: P-value cut-off for selection of significantly DE genes.
         
     | 
| 
       175 
     | 
    
         
            -
                        perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. 
     | 
| 
      
 181 
     | 
    
         
            +
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
         
     | 
| 
       176 
182 
     | 
    
         
             
                        copy: Determines whether a copy of the `adata` is returned.
         
     | 
| 
       177 
183 
     | 
    
         | 
| 
       178 
184 
     | 
    
         
             
                    Returns:
         
     | 
| 
       179 
185 
     | 
    
         
             
                        If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
         
     | 
| 
       180 
     | 
    
         
            -
                        Otherwise writes the results directly to `.obs` of the provided `adata`.
         
     | 
| 
      
 186 
     | 
    
         
            +
                        Otherwise, writes the results directly to `.obs` of the provided `adata`.
         
     | 
| 
       181 
187 
     | 
    
         | 
| 
       182 
     | 
    
         
            -
                        mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
         
     | 
| 
       183 
     | 
    
         
            -
             
     | 
| 
      
 188 
     | 
    
         
            +
                        - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
         
     | 
| 
      
 189 
     | 
    
         
            +
                          Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
         
     | 
| 
       184 
190 
     | 
    
         | 
| 
       185 
     | 
    
         
            -
                        mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
         
     | 
| 
       186 
     | 
    
         
            -
             
     | 
| 
      
 191 
     | 
    
         
            +
                        - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
         
     | 
| 
      
 192 
     | 
    
         
            +
                          Global classification result (perturbed, NP or NT).
         
     | 
| 
       187 
193 
     | 
    
         | 
| 
       188 
     | 
    
         
            -
                        mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
         
     | 
| 
       189 
     | 
    
         
            -
             
     | 
| 
       190 
     | 
    
         
            -
             
     | 
| 
      
 194 
     | 
    
         
            +
                        - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
         
     | 
| 
      
 195 
     | 
    
         
            +
                          Posterior probabilities used to determine if a cell is KO (default).
         
     | 
| 
      
 196 
     | 
    
         
            +
                          The name of this item changes to match the `perturbation_type` parameter setting. Cells with a posterior probability > 0.5 are classified as perturbed; all others as NP.
         
     | 
| 
       191 
197 
     | 
    
         | 
| 
       192 
198 
     | 
    
         
             
                    Examples:
         
     | 
| 
       193 
199 
     | 
    
         
             
                        Calculate perturbation signature for each cell in the dataset:
         
     | 
| 
       194 
200 
     | 
    
         | 
| 
       195 
201 
     | 
    
         
             
                        >>> import pertpy as pt
         
     | 
| 
       196 
202 
     | 
    
         
             
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       197 
     | 
    
         
            -
                        >>>  
     | 
| 
       198 
     | 
    
         
            -
                        >>>  
     | 
| 
       199 
     | 
    
         
            -
                        >>>  
     | 
| 
      
 203 
     | 
    
         
            +
                        >>> ms_pt = pt.tl.Mixscape()
         
     | 
| 
      
 204 
     | 
    
         
            +
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
         
     | 
| 
      
 205 
     | 
    
         
            +
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
         
     | 
| 
       200 
206 
     | 
    
         
             
                    """
         
     | 
| 
       201 
207 
     | 
    
         
             
                    if copy:
         
     | 
| 
       202 
208 
     | 
    
         
             
                        adata = adata.copy()
         
     | 
| 
         @@ -220,10 +226,9 @@ class Mixscape: 
     | 
|
| 
       220 
226 
     | 
    
         
             
                        try:
         
     | 
| 
       221 
227 
     | 
    
         
             
                            X = adata_comp.layers["X_pert"]
         
     | 
| 
       222 
228 
     | 
    
         
             
                        except KeyError:
         
     | 
| 
       223 
     | 
    
         
            -
                             
     | 
| 
       224 
     | 
    
         
            -
                                 
     | 
| 
       225 
     | 
    
         
            -
                            )
         
     | 
| 
       226 
     | 
    
         
            -
                            raise
         
     | 
| 
      
 229 
     | 
    
         
            +
                            raise KeyError(
         
     | 
| 
      
 230 
     | 
    
         
            +
                                "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
         
     | 
| 
      
 231 
     | 
    
         
            +
                            ) from None
         
     | 
| 
       227 
232 
     | 
    
         
             
                    # initialize return variables
         
     | 
| 
       228 
233 
     | 
    
         
             
                    adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
         
     | 
| 
       229 
234 
     | 
    
         
             
                    adata.obs[new_class_name] = adata.obs[labels].astype(str)
         
     | 
| 
         @@ -305,12 +310,14 @@ class Mixscape: 
     | 
|
| 
       305 
310 
     | 
    
         
             
                                    old_classes = adata.obs[new_class_name][all_cells]
         
     | 
| 
       306 
311 
     | 
    
         
             
                                    n_iter += 1
         
     | 
| 
       307 
312 
     | 
    
         | 
| 
       308 
     | 
    
         
            -
                                adata.obs.loc[
         
     | 
| 
       309 
     | 
    
         
            -
                                     
     | 
| 
       310 
     | 
    
         
            -
                                 
     | 
| 
      
 313 
     | 
    
         
            +
                                adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
         
     | 
| 
      
 314 
     | 
    
         
            +
                                    f"{gene} {perturbation_type}"
         
     | 
| 
      
 315 
     | 
    
         
            +
                                )
         
     | 
| 
       311 
316 
     | 
    
         | 
| 
       312 
317 
     | 
    
         
             
                            adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
         
     | 
| 
       313 
     | 
    
         
            -
                            adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =  
     | 
| 
      
 318 
     | 
    
         
            +
                            adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
         
     | 
| 
      
 319 
     | 
    
         
            +
                                post_prob
         
     | 
| 
      
 320 
     | 
    
         
            +
                            ).astype("int64")
         
     | 
| 
       314 
321 
     | 
    
         
             
                    adata.uns["mixscape"] = gv_list
         
     | 
| 
       315 
322 
     | 
    
         | 
| 
       316 
323 
     | 
    
         
             
                    if copy:
         
     | 
| 
         @@ -339,18 +346,18 @@ class Mixscape: 
     | 
|
| 
       339 
346 
     | 
    
         
             
                        control: Control category from the `pert_key` column.
         
     | 
| 
       340 
347 
     | 
    
         
             
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
         
     | 
| 
       341 
348 
     | 
    
         
             
                        layer: Key from `adata.layers` whose value will be used to perform tests on.
         
     | 
| 
       342 
     | 
    
         
            -
                        control: Control category from the `pert_key` column. 
     | 
| 
       343 
     | 
    
         
            -
                        n_comps: Number of principal components to use. 
     | 
| 
      
 349 
     | 
    
         
            +
                        control: Control category from the `pert_key` column.
         
     | 
| 
      
 350 
     | 
    
         
            +
                        n_comps: Number of principal components to use.
         
     | 
| 
       344 
351 
     | 
    
         
             
                        min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
         
     | 
| 
       345 
     | 
    
         
            -
                        logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. 
     | 
| 
      
 352 
     | 
    
         
            +
                        logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
         
     | 
| 
       346 
353 
     | 
    
         
             
                        split_by: Provide the column of `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately.
         
     | 
| 
       347 
354 
     | 
    
         
             
                        pval_cutoff: P-value cut-off for selection of significantly DE genes.
         
     | 
| 
       348 
     | 
    
         
            -
                        perturbation_type:  
     | 
| 
      
 355 
     | 
    
         
            +
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
         
     | 
| 
       349 
356 
     | 
    
         
             
                        copy: Determines whether a copy of the `adata` is returned.
         
     | 
| 
       350 
357 
     | 
    
         | 
| 
       351 
358 
     | 
    
         
             
                    Returns:
         
     | 
| 
       352 
359 
     | 
    
         
             
                        If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
         
     | 
| 
       353 
     | 
    
         
            -
                        Otherwise writes the results directly to `.uns` of the provided `adata`.
         
     | 
| 
      
 360 
     | 
    
         
            +
                        Otherwise, writes the results directly to `.uns` of the provided `adata`.
         
     | 
| 
       354 
361 
     | 
    
         | 
| 
       355 
362 
     | 
    
         
             
                        mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
         
     | 
| 
       356 
363 
     | 
    
         
             
                        LDA result.
         
     | 
| 
         @@ -360,10 +367,10 @@ class Mixscape: 
     | 
|
| 
       360 
367 
     | 
    
         | 
| 
       361 
368 
     | 
    
         
             
                        >>> import pertpy as pt
         
     | 
| 
       362 
369 
     | 
    
         
             
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       363 
     | 
    
         
            -
                        >>>  
     | 
| 
       364 
     | 
    
         
            -
                        >>>  
     | 
| 
       365 
     | 
    
         
            -
                        >>>  
     | 
| 
       366 
     | 
    
         
            -
                        >>>  
     | 
| 
      
 370 
     | 
    
         
            +
                        >>> ms_pt = pt.tl.Mixscape()
         
     | 
| 
      
 371 
     | 
    
         
            +
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
         
     | 
| 
      
 372 
     | 
    
         
            +
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
         
     | 
| 
      
 373 
     | 
    
         
            +
                        >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
         
     | 
| 
       367 
374 
     | 
    
         
             
                    """
         
     | 
| 
       368 
375 
     | 
    
         
             
                    if copy:
         
     | 
| 
       369 
376 
     | 
    
         
             
                        adata = adata.copy()
         
     | 
| 
         @@ -437,7 +444,7 @@ class Mixscape: 
     | 
|
| 
       437 
444 
     | 
    
         
             
                    min_de_genes: float,
         
     | 
| 
       438 
445 
     | 
    
         
             
                    logfc_threshold: float,
         
     | 
| 
       439 
446 
     | 
    
         
             
                ) -> dict[tuple, np.ndarray]:
         
     | 
| 
       440 
     | 
    
         
            -
                    """ 
     | 
| 
      
 447 
     | 
    
         
            +
                    """Determine gene sets across all splits/groups through differential gene expression
         
     | 
| 
       441 
448 
     | 
    
         | 
| 
       442 
449 
     | 
    
         
             
                    Args:
         
     | 
| 
       443 
450 
     | 
    
         
             
                        adata: :class:`~anndata.AnnData` object
         
     | 
| 
         @@ -454,7 +461,13 @@ class Mixscape: 
     | 
|
| 
       454 
461 
     | 
    
         
             
                        adata_split = adata[split_mask].copy()
         
     | 
| 
       455 
462 
     | 
    
         
             
                        # find top DE genes between cells with targeting and non-targeting gRNAs
         
     | 
| 
       456 
463 
     | 
    
         
             
                        sc.tl.rank_genes_groups(
         
     | 
| 
       457 
     | 
    
         
            -
                            adata_split, 
     | 
| 
      
 464 
     | 
    
         
            +
                            adata_split,
         
     | 
| 
      
 465 
     | 
    
         
            +
                            layer=layer,
         
     | 
| 
      
 466 
     | 
    
         
            +
                            groupby=labels,
         
     | 
| 
      
 467 
     | 
    
         
            +
                            groups=genes,
         
     | 
| 
      
 468 
     | 
    
         
            +
                            reference=control,
         
     | 
| 
      
 469 
     | 
    
         
            +
                            method="t-test",
         
     | 
| 
      
 470 
     | 
    
         
            +
                            use_raw=False,
         
     | 
| 
       458 
471 
     | 
    
         
             
                        )
         
     | 
| 
       459 
472 
     | 
    
         
             
                        # get DE genes for each gene
         
     | 
| 
       460 
473 
     | 
    
         
             
                        for gene in genes:
         
     | 
| 
         @@ -469,15 +482,6 @@ class Mixscape: 
     | 
|
| 
       469 
482 
     | 
    
         
             
                    return perturbation_markers
         
     | 
| 
       470 
483 
     | 
    
         | 
| 
       471 
484 
     | 
    
         
             
                def _get_column_indices(self, adata, col_names):
         
     | 
| 
       472 
     | 
    
         
            -
                    """Fetches the column indices in X for a given list of column names
         
     | 
| 
       473 
     | 
    
         
            -
             
     | 
| 
       474 
     | 
    
         
            -
                    Args:
         
     | 
| 
       475 
     | 
    
         
            -
                        adata: :class:`~anndata.AnnData` object
         
     | 
| 
       476 
     | 
    
         
            -
                        col_names: Column names to extract the indices for
         
     | 
| 
       477 
     | 
    
         
            -
             
     | 
| 
       478 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       479 
     | 
    
         
            -
                        Set of column indices
         
     | 
| 
       480 
     | 
    
         
            -
                    """
         
     | 
| 
       481 
485 
     | 
    
         
             
                    if isinstance(col_names, str):  # pragma: no cover
         
     | 
| 
       482 
486 
     | 
    
         
             
                        col_names = [col_names]
         
     | 
| 
       483 
487 
     | 
    
         | 
| 
         @@ -501,3 +505,621 @@ class Mixscape: 
     | 
|
| 
       501 
505 
     | 
    
         
             
                    sd = X.std()
         
     | 
| 
       502 
506 
     | 
    
         | 
| 
       503 
507 
     | 
    
         
             
                    return [mu, sd]
         
     | 
| 
      
 508 
     | 
    
         
            +
             
     | 
| 
      
 509 
     | 
    
         
            +
                def plot_barplot(  # pragma: no cover
                    self,
                    adata: AnnData,
                    guide_rna_column: str,
                    mixscape_class_global: str = "mixscape_class_global",
                    axis_text_x_size: int = 8,
                    axis_text_y_size: int = 6,
                    axis_title_size: int = 8,
                    legend_title_size: int = 8,
                    legend_text_size: int = 8,
                    return_fig: bool | None = None,
                    ax: Axes | None = None,
                    show: bool | None = None,
                    save: bool | str | None = None,
                ):
                    """Barplot to visualize perturbation scores calculated by the `mixscape` function.

                    One stacked bar chart is drawn per target gene (five panels per row), showing for every
                    guide RNA the fraction of cells in each global mixscape class.

                    Args:
                        adata: The annotated data object.
                        guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
                                          The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
                        axis_text_x_size: Font size of the x tick labels.
                        axis_text_y_size: Font size of the y tick labels.
                        axis_title_size: Font size of the per-gene panel titles.
                        legend_title_size: Font size of the legend title.
                        legend_text_size: Font size of the legend entries.
                        return_fig: Currently unused by this method.
                        ax: Currently unused by this method; a new figure with one panel per gene is always created.
                        show: Forwarded to `_utils.savefig_or_show`.
                        save: If True or a str, save the figure. A string is appended to the default filename.
                              Infer the filetype if ending on {'.pdf', '.png', '.svg'}.

                    Returns:
                        None. The figure is handed to `_utils.savefig_or_show` together with `show` and `save`.

                    Raises:
                        ValueError: If `mixscape_class_global` is not a column of `adata.obs`.

                    Examples:
                        >>> import pertpy as pt
                        >>> mdata = pt.dt.papalexi_2021()
                        >>> ms_pt = pt.tl.Mixscape()
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
                        >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")

                    Preview:
                        .. image:: /_static/docstring_previews/mixscape_barplot.png
                    """
                    if mixscape_class_global not in adata.obs:
                        raise ValueError("Please run the `mixscape` function first.")

                    # Fraction of cells per global class, normalised within every guide RNA (column-wise).
                    count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
                    all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
                    KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
                    KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)

                    # Order the guides by decreasing KO fraction.
                    new_levels = KO_cells_percentage[guide_rna_column]
                    all_cells_percentage[guide_rna_column] = pd.Categorical(
                        all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
                    )
                    all_cells_percentage[mixscape_class_global] = pd.Categorical(
                        all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
                    )

                    # Split only at the LAST "g" so that gene names which themselves contain a lowercase
                    # "g" are kept intact ("<gene_target>g<#>" -> "<gene_target>", "<#>").
                    guide_parts = all_cells_percentage[guide_rna_column].str.rsplit("g", n=1, expand=True)
                    all_cells_percentage["gene"] = guide_parts[0]
                    all_cells_percentage["guide_number"] = "g" + guide_parts[1]
                    NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]

                    color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
                    unique_genes = NP_KO_cells["gene"].dropna().unique()

                    # Ceil-divide so that gene counts that are not a multiple of five still get enough rows;
                    # squeeze=False keeps `axs` two-dimensional even for a single row.
                    n_cols = 5
                    n_rows = max(1, (len(unique_genes) + n_cols - 1) // n_cols)
                    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 25), sharey=True, squeeze=False)

                    gene_ax = None
                    for i, gene in enumerate(unique_genes):
                        gene_ax = axs[i // n_cols, i % n_cols]
                        grouped_df = (
                            NP_KO_cells[NP_KO_cells["gene"] == gene]
                            .groupby(["guide_number", mixscape_class_global], observed=False)["value"]
                            .sum()
                            .unstack()
                        )
                        grouped_df.plot(
                            kind="bar",
                            stacked=True,
                            color=[color_mapping[col] for col in grouped_df.columns],
                            ax=gene_ax,
                            width=0.8,
                            legend=False,
                        )
                        gene_ax.set_title(
                            gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
                        )
                        gene_ax.set(xlabel="sgRNA", ylabel="% of cells")
                        sns.despine(ax=gene_ax, top=True, right=True, left=False, bottom=False)
                        gene_ax.set_xticklabels(gene_ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
                        gene_ax.set_yticklabels(gene_ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)

                    # Panels of the last row that have no gene assigned stay blank.
                    for unused_ax in axs.flat[len(unique_genes) :]:
                        unused_ax.set_visible(False)

                    fig.subplots_adjust(right=0.8)
                    fig.subplots_adjust(hspace=0.5, wspace=0.5)
                    # The legend is attached to the last populated panel; skip it when nothing was drawn.
                    if gene_ax is not None:
                        gene_ax.legend(
                            title=mixscape_class_global,
                            loc="center right",
                            bbox_to_anchor=(2.2, 3.5),
                            frameon=True,
                            fontsize=legend_text_size,
                            title_fontsize=legend_title_size,
                        )

                    plt.tight_layout()
                    _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
         
     | 
| 
      
 608 
     | 
    
         
            +
             
     | 
| 
      
 609 
     | 
    
         
            +
                def plot_heatmap(  # pragma: no cover
                    self,
                    adata: AnnData,
                    labels: str,
                    target_gene: str,
                    control: str,
                    layer: str | None = None,
                    method: str | None = "wilcoxon",
                    subsample_number: int | None = 900,
                    vmin: float | None = -2,
                    vmax: float | None = 2,
                    return_fig: bool | None = None,
                    show: bool | None = None,
                    save: bool | str | None = None,
                    **kwds,
                ) -> Axes | None:
                    """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.

                    The cells labelled `target_gene` or `control` are copied, genes are ranked between the two
                    label groups, the copy is scaled and subsampled, and the top 20 ranked genes of the control
                    group are drawn with cells grouped by their `mixscape_class`.

                    Args:
                        adata: The annotated data object.
                        labels: The column of `.obs` with target gene labels.
                        target_gene: Target gene name to visualize heatmap for.
                        control: Control category from the `labels` column.
                        layer: Key from `adata.layers` whose value will be used to perform tests on.
                        method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
                        subsample_number: Subsample to this number of observations. Capped at the number of
                                          selected cells; `None` disables subsampling.
                        vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
                        vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
                              Also passed as `max_value` to `scanpy.pp.scale`.
                        return_fig: Forwarded to `scanpy.pl.rank_genes_groups_heatmap`.
                        show: Show the plot, do not return axis.
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
                              Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
                        **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.

                    Returns:
                        The return value of `scanpy.pl.rank_genes_groups_heatmap` for the given `show` and
                        `return_fig` settings.

                    Raises:
                        ValueError: If `adata.obs` has no `mixscape_class` column.

                    Examples:
                        >>> import pertpy as pt
                        >>> mdata = pt.dt.papalexi_2021()
                        >>> ms_pt = pt.tl.Mixscape()
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
                        >>> ms_pt.plot_heatmap(
                        ...     adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
                        ... )

                    Preview:
                        .. image:: /_static/docstring_previews/mixscape_heatmap.png
                    """
                    if "mixscape_class" not in adata.obs:
                        raise ValueError("Please run `pt.tl.mixscape` first.")

                    # Work on a copy so that scaling and subsampling never touch the caller's object.
                    selected_cells = adata.obs[labels].isin([target_gene, control])
                    adata_subset = adata[selected_cells].copy()
                    sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
                    sc.pp.scale(adata_subset, max_value=vmax)

                    # Requesting more observations than are available makes the subsampling fail, so the
                    # request is capped at the subset size; `None` means "keep every cell".
                    if subsample_number is not None:
                        sc.pp.subsample(adata_subset, n_obs=min(subsample_number, adata_subset.n_obs))

                    return sc.pl.rank_genes_groups_heatmap(
                        adata_subset,
                        groupby="mixscape_class",
                        vmin=vmin,
                        vmax=vmax,
                        n_genes=20,
                        # Genes were ranked per category of `labels`; show those of the control category
                        # rather than assuming it is literally called "NT".
                        groups=[control],
                        return_fig=return_fig,
                        show=show,
                        save=save,
                        **kwds,
                    )
         
     | 
| 
      
 678 
     | 
    
         
            +
             
     | 
| 
      
 679 
     | 
    
         
            +
                def plot_perturbscore(  # pragma: no cover
         
     | 
| 
      
 680 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 681 
     | 
    
         
            +
                    adata: AnnData,
         
     | 
| 
      
 682 
     | 
    
         
            +
                    labels: str,
         
     | 
| 
      
 683 
     | 
    
         
            +
                    target_gene: str,
         
     | 
| 
      
 684 
     | 
    
         
            +
                    mixscape_class: str = "mixscape_class",
         
     | 
| 
      
 685 
     | 
    
         
            +
                    color: str = "orange",
         
     | 
| 
      
 686 
     | 
    
         
            +
                    palette: dict[str, str] = None,
         
     | 
| 
      
 687 
     | 
    
         
            +
                    split_by: str = None,
         
     | 
| 
      
 688 
     | 
    
         
            +
                    before_mixscape: bool = False,
         
     | 
| 
      
 689 
     | 
    
         
            +
                    perturbation_type: str = "KO",
         
     | 
| 
      
 690 
     | 
    
         
            +
                    return_fig: bool | None = None,
         
     | 
| 
      
 691 
     | 
    
         
            +
                    ax: Axes | None = None,
         
     | 
| 
      
 692 
     | 
    
         
            +
                    show: bool | None = None,
         
     | 
| 
      
 693 
     | 
    
         
            +
                    save: bool | str | None = None,
         
     | 
| 
      
 694 
     | 
    
         
            +
                ) -> None:
         
     | 
| 
      
 695 
     | 
    
         
            +
                    """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
         
     | 
| 
      
 696 
     | 
    
         
            +
             
     | 
| 
      
 697 
     | 
    
         
            +
                    Requires `pt.tl.mixscape` to be run first.
         
     | 
| 
      
 698 
     | 
    
         
            +
             
     | 
| 
      
 699 
     | 
    
         
            +
                    https://satijalab.org/seurat/reference/plotperturbscore
         
     | 
| 
      
 700 
     | 
    
         
            +
             
     | 
| 
      
 701 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 702 
     | 
    
         
            +
                        adata: The annotated data object.
         
     | 
| 
      
 703 
     | 
    
         
            +
                        labels: The column of `.obs` with target gene labels.
         
     | 
| 
      
 704 
     | 
    
         
            +
                        target_gene: Target gene name to visualize perturbation scores for.
         
     | 
| 
      
 705 
     | 
    
         
            +
                        mixscape_class: The column of `.obs` with mixscape classifications.
         
     | 
| 
      
 706 
     | 
    
         
            +
                        color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
         
     | 
| 
      
 707 
     | 
    
         
            +
                        palette: Optional full color palette to overwrite all colors.
         
     | 
| 
      
 708 
     | 
    
         
            +
                        split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
         
     | 
| 
      
 709 
     | 
    
         
            +
                                  the perturbation signature for every replicate separately.
         
     | 
| 
      
 710 
     | 
    
         
            +
                        before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
         
     | 
| 
      
 711 
     | 
    
         
            +
                                         Default is set to NULL and plots cells by original class ID.
         
     | 
| 
      
 712 
     | 
    
         
            +
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
         
     | 
| 
      
 713 
     | 
    
         
            +
             
     | 
| 
      
 714 
     | 
    
         
            +
                    Examples:
         
     | 
| 
      
 715 
     | 
    
         
            +
                        Visualizing the perturbation scores for the cells in a dataset:
         
     | 
| 
      
 716 
     | 
    
         
            +
             
     | 
| 
      
 717 
     | 
    
         
            +
                        >>> import pertpy as pt
         
     | 
| 
      
 718 
     | 
    
         
            +
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
      
 719 
     | 
    
         
            +
                        >>> ms_pt = pt.tl.Mixscape()
         
     | 
| 
      
 720 
     | 
    
         
            +
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
         
     | 
| 
      
 721 
     | 
    
         
            +
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
         
     | 
| 
      
 722 
     | 
    
         
            +
                        >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
         
     | 
| 
      
 723 
     | 
    
         
            +
             
     | 
| 
      
 724 
     | 
    
         
            +
                    Preview:
         
     | 
| 
      
 725 
     | 
    
         
            +
                        .. image:: /_static/docstring_previews/mixscape_perturbscore.png
         
     | 
| 
      
 726 
     | 
    
         
            +
                    """
         
     | 
| 
      
 727 
     | 
    
         
            +
                    if "mixscape" not in adata.uns:
         
     | 
| 
      
 728 
     | 
    
         
            +
                        raise ValueError("Please run the `mixscape` function first.")
         
     | 
| 
      
 729 
     | 
    
         
            +
                    perturbation_score = None
         
     | 
| 
      
 730 
     | 
    
         
            +
                    for key in adata.uns["mixscape"][target_gene].keys():
         
     | 
| 
      
 731 
     | 
    
         
            +
                        perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
         
     | 
| 
      
 732 
     | 
    
         
            +
                        perturbation_score_temp["name"] = key
         
     | 
| 
      
 733 
     | 
    
         
            +
                        if perturbation_score is None:
         
     | 
| 
      
 734 
     | 
    
         
            +
                            perturbation_score = copy.deepcopy(perturbation_score_temp)
         
     | 
| 
      
 735 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 736 
     | 
    
         
            +
                            perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
         
     | 
| 
      
 737 
     | 
    
         
            +
                    perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
         
     | 
| 
      
 738 
     | 
    
         
            +
                    gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
         
     | 
| 
      
 739 
     | 
    
         
            +
             
     | 
| 
      
 740 
     | 
    
         
            +
                    # If before_mixscape is True, split densities based on original target gene classification
         
     | 
| 
      
 741 
     | 
    
         
            +
                    if before_mixscape is True:
         
     | 
| 
      
 742 
     | 
    
         
            +
                        palette = {gd: "#7d7d7d", target_gene: color}
         
     | 
| 
      
 743 
     | 
    
         
            +
                        plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
         
     | 
| 
      
 744 
     | 
    
         
            +
                        top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
         
     | 
| 
      
 745 
     | 
    
         
            +
                        plt.close()
         
     | 
| 
      
 746 
     | 
    
         
            +
                        perturbation_score["y_jitter"] = perturbation_score["pvec"]
         
     | 
| 
      
 747 
     | 
    
         
            +
                        rng = np.random.default_rng()
         
     | 
| 
      
 748 
     | 
    
         
            +
                        perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
         
     | 
| 
      
 749 
     | 
    
         
            +
                            low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
         
     | 
| 
      
 750 
     | 
    
         
            +
                        )
         
     | 
| 
      
 751 
     | 
    
         
            +
                        perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
         
     | 
| 
      
 752 
     | 
    
         
            +
                            low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
         
     | 
| 
      
 753 
     | 
    
         
            +
                        )
         
     | 
| 
      
 754 
     | 
    
         
            +
                        # If split_by is provided, split densities based on the split_by
         
     | 
| 
      
 755 
     | 
    
         
            +
                        if split_by is not None:
         
     | 
| 
      
 756 
     | 
    
         
            +
                            sns.set_theme(style="whitegrid")
         
     | 
| 
      
 757 
     | 
    
         
            +
                            g = sns.FacetGrid(
         
     | 
| 
      
 758 
     | 
    
         
            +
                                data=perturbation_score, col=split_by, hue=split_by, palette=palette, height=5, sharey=False
         
     | 
| 
      
 759 
     | 
    
         
            +
                            )
         
     | 
| 
      
 760 
     | 
    
         
            +
                            g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, palette=palette)
         
     | 
| 
      
 761 
     | 
    
         
            +
                            g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5, palette=palette)
         
     | 
| 
      
 762 
     | 
    
         
            +
                            g.set_axis_labels("Perturbation score", "Cell density")
         
     | 
| 
      
 763 
     | 
    
         
            +
                            g.add_legend(title=split_by, fontsize=14, title_fontsize=16)
         
     | 
| 
      
 764 
     | 
    
         
            +
                            g.despine(left=True)
         
     | 
| 
      
 765 
     | 
    
         
            +
             
     | 
| 
      
 766 
     | 
    
         
            +
                        # If split_by is not provided, create a single plot
         
     | 
| 
      
 767 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 768 
     | 
    
         
            +
                            sns.set_theme(style="whitegrid")
         
     | 
| 
      
 769 
     | 
    
         
            +
                            sns.kdeplot(
         
     | 
| 
      
 770 
     | 
    
         
            +
                                data=perturbation_score, x="pvec", hue="gene_target", fill=True, common_norm=False, palette=palette
         
     | 
| 
      
 771 
     | 
    
         
            +
                            )
         
     | 
| 
      
 772 
     | 
    
         
            +
                            sns.scatterplot(
         
     | 
| 
      
 773 
     | 
    
         
            +
                                data=perturbation_score, x="pvec", y="y_jitter", hue="gene_target", palette=palette, s=10, alpha=0.5
         
     | 
| 
      
 774 
     | 
    
         
            +
                            )
         
     | 
| 
      
 775 
     | 
    
         
            +
                            plt.xlabel("Perturbation score", fontsize=16)
         
     | 
| 
      
 776 
     | 
    
         
            +
                            plt.ylabel("Cell density", fontsize=16)
         
     | 
| 
      
 777 
     | 
    
         
            +
                            plt.title("Density Plot", fontsize=18)
         
     | 
| 
      
 778 
     | 
    
         
            +
                            plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
         
     | 
| 
      
 779 
     | 
    
         
            +
                            sns.despine()
         
     | 
| 
      
 780 
     | 
    
         
            +
             
     | 
| 
      
 781 
     | 
    
         
            +
                        if save:
         
     | 
| 
      
 782 
     | 
    
         
            +
                            plt.savefig(save, bbox_inches="tight")
         
     | 
| 
      
 783 
     | 
    
         
            +
                        if show:
         
     | 
| 
      
 784 
     | 
    
         
            +
                            plt.show()
         
     | 
| 
      
 785 
     | 
    
         
            +
                        if return_fig:
         
     | 
| 
      
 786 
     | 
    
         
            +
                            return plt.gcf()
         
     | 
| 
      
 787 
     | 
    
         
            +
                        if not (show or save):
         
     | 
| 
      
 788 
     | 
    
         
            +
                            return plt.gca()
         
     | 
| 
      
 789 
     | 
    
         
            +
             
     | 
| 
      
 790 
     | 
    
         
            +
                    # If before_mixscape is False, split densities based on mixscape classifications
         
     | 
| 
      
 791 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 792 
     | 
    
         
            +
                        if palette is None:
         
     | 
| 
      
 793 
     | 
    
         
            +
                            palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
         
     | 
| 
      
 794 
     | 
    
         
            +
                        plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
         
     | 
| 
      
 795 
     | 
    
         
            +
                        top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
         
     | 
| 
      
 796 
     | 
    
         
            +
                        plt.close()
         
     | 
| 
      
 797 
     | 
    
         
            +
                        perturbation_score["y_jitter"] = perturbation_score["pvec"]
         
     | 
| 
      
 798 
     | 
    
         
            +
                        rng = np.random.default_rng()
         
     | 
| 
      
 799 
     | 
    
         
            +
                        gd2 = list(
         
     | 
| 
      
 800 
     | 
    
         
            +
                            set(perturbation_score["mix"]).difference([f"{target_gene} NP", f"{target_gene} {perturbation_type}"])
         
     | 
| 
      
 801 
     | 
    
         
            +
                        )[0]
         
     | 
| 
      
 802 
     | 
    
         
            +
                        perturbation_score.loc[perturbation_score["mix"] == gd2, "y_jitter"] = rng.uniform(
         
     | 
| 
      
 803 
     | 
    
         
            +
                            low=0.001, high=top_r / 10, size=sum(perturbation_score["mix"] == gd2)
         
     | 
| 
      
 804 
     | 
    
         
            +
                        ).astype(np.float32)
         
     | 
| 
      
 805 
     | 
    
         
            +
                        perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} {perturbation_type}", "y_jitter"] = (
         
     | 
| 
      
 806 
     | 
    
         
            +
                            rng.uniform(
         
     | 
| 
      
 807 
     | 
    
         
            +
                                low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} {perturbation_type}")
         
     | 
| 
      
 808 
     | 
    
         
            +
                            )
         
     | 
| 
      
 809 
     | 
    
         
            +
                        )
         
     | 
| 
      
 810 
     | 
    
         
            +
                        perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} NP", "y_jitter"] = rng.uniform(
         
     | 
| 
      
 811 
     | 
    
         
            +
                            low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} NP")
         
     | 
| 
      
 812 
     | 
    
         
            +
                        )
         
     | 
| 
      
 813 
     | 
    
         
            +
                        # If split_by is provided, split densities based on the split_by
         
     | 
| 
      
 814 
     | 
    
         
            +
                        if split_by is not None:
         
     | 
| 
      
 815 
     | 
    
         
            +
                            sns.set_theme(style="whitegrid")
         
     | 
| 
      
 816 
     | 
    
         
            +
                            g = sns.FacetGrid(
         
     | 
| 
      
 817 
     | 
    
         
            +
                                data=perturbation_score, col=split_by, hue="mix", palette=palette, height=5, sharey=False
         
     | 
| 
      
 818 
     | 
    
         
            +
                            )
         
     | 
| 
      
 819 
     | 
    
         
            +
                            g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, alpha=0.7)
         
     | 
| 
      
 820 
     | 
    
         
            +
                            g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
         
     | 
| 
      
 821 
     | 
    
         
            +
                            g.set_axis_labels("Perturbation score", "Cell density")
         
     | 
| 
      
 822 
     | 
    
         
            +
                            g.add_legend(title="mix", fontsize=14, title_fontsize=16)
         
     | 
| 
      
 823 
     | 
    
         
            +
                            g.despine(left=True)
         
     | 
| 
      
 824 
     | 
    
         
            +
             
     | 
| 
      
 825 
     | 
    
         
            +
                        # If split_by is not provided, create a single plot
         
     | 
| 
      
 826 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 827 
     | 
    
         
            +
                            sns.set_theme(style="whitegrid")
         
     | 
| 
      
 828 
     | 
    
         
            +
                            sns.kdeplot(
         
     | 
| 
      
 829 
     | 
    
         
            +
                                data=perturbation_score,
         
     | 
| 
      
 830 
     | 
    
         
            +
                                x="pvec",
         
     | 
| 
      
 831 
     | 
    
         
            +
                                hue="mix",
         
     | 
| 
      
 832 
     | 
    
         
            +
                                fill=True,
         
     | 
| 
      
 833 
     | 
    
         
            +
                                common_norm=False,
         
     | 
| 
      
 834 
     | 
    
         
            +
                                palette=palette,
         
     | 
| 
      
 835 
     | 
    
         
            +
                                alpha=0.7,
         
     | 
| 
      
 836 
     | 
    
         
            +
                            )
         
     | 
| 
      
 837 
     | 
    
         
            +
                            sns.scatterplot(
         
     | 
| 
      
 838 
     | 
    
         
            +
                                data=perturbation_score, x="pvec", y="y_jitter", hue="mix", palette=palette, s=10, alpha=0.5
         
     | 
| 
      
 839 
     | 
    
         
            +
                            )
         
     | 
| 
      
 840 
     | 
    
         
            +
                            plt.xlabel("Perturbation score", fontsize=16)
         
     | 
| 
      
 841 
     | 
    
         
            +
                            plt.ylabel("Cell density", fontsize=16)
         
     | 
| 
      
 842 
     | 
    
         
            +
                            plt.title("Density", fontsize=18)
         
     | 
| 
      
 843 
     | 
    
         
            +
                            plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
         
     | 
| 
      
 844 
     | 
    
         
            +
                            sns.despine()
         
     | 
| 
      
 845 
     | 
    
         
            +
             
     | 
| 
      
 846 
     | 
    
         
            +
                        if save:
         
     | 
| 
      
 847 
     | 
    
         
            +
                            plt.savefig(save, bbox_inches="tight")
         
     | 
| 
      
 848 
     | 
    
         
            +
                        if show:
         
     | 
| 
      
 849 
     | 
    
         
            +
                            plt.show()
         
     | 
| 
      
 850 
     | 
    
         
            +
                        if return_fig:
         
     | 
| 
      
 851 
     | 
    
         
            +
                            return plt.gcf()
         
     | 
| 
      
 852 
     | 
    
         
            +
                        if not (show or save):
         
     | 
| 
      
 853 
     | 
    
         
            +
                            return plt.gca()
         
     | 
| 
      
 854 
     | 
    
         
            +
             
     | 
| 
      
 855 
     | 
    
         
            +
                def plot_violin(  # pragma: no cover
         
     | 
| 
      
 856 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 857 
     | 
    
         
            +
                    adata: AnnData,
         
     | 
| 
      
 858 
     | 
    
         
            +
                    target_gene_idents: str | list[str],
         
     | 
| 
      
 859 
     | 
    
         
            +
                    keys: str | Sequence[str] = "mixscape_class_p_ko",
         
     | 
| 
      
 860 
     | 
    
         
            +
                    groupby: str | None = "mixscape_class",
         
     | 
| 
      
 861 
     | 
    
         
            +
                    log: bool = False,
         
     | 
| 
      
 862 
     | 
    
         
            +
                    use_raw: bool | None = None,
         
     | 
| 
      
 863 
     | 
    
         
            +
                    stripplot: bool = True,
         
     | 
| 
      
 864 
     | 
    
         
            +
                    hue: str | None = None,
         
     | 
| 
      
 865 
     | 
    
         
            +
                    jitter: float | bool = True,
         
     | 
| 
      
 866 
     | 
    
         
            +
                    size: int = 1,
         
     | 
| 
      
 867 
     | 
    
         
            +
                    layer: str | None = None,
         
     | 
| 
      
 868 
     | 
    
         
            +
                    scale: Literal["area", "count", "width"] = "width",
         
     | 
| 
      
 869 
     | 
    
         
            +
                    order: Sequence[str] | None = None,
         
     | 
| 
      
 870 
     | 
    
         
            +
                    multi_panel: bool | None = None,
         
     | 
| 
      
 871 
     | 
    
         
            +
                    xlabel: str = "",
         
     | 
| 
      
 872 
     | 
    
         
            +
                    ylabel: str | Sequence[str] | None = None,
         
     | 
| 
      
 873 
     | 
    
         
            +
                    rotation: float | None = None,
         
     | 
| 
      
 874 
     | 
    
         
            +
                    ax: Axes | None = None,
         
     | 
| 
      
 875 
     | 
    
         
            +
                    show: bool | None = None,
         
     | 
| 
      
 876 
     | 
    
         
            +
                    save: bool | str | None = None,
         
     | 
| 
      
 877 
     | 
    
         
            +
                    **kwargs,
         
     | 
| 
      
 878 
     | 
    
         
            +
                ):
         
     | 
| 
      
 879 
     | 
    
         
            +
                    """Violin plot using mixscape results.
         
     | 
| 
      
 880 
     | 
    
         
            +
             
     | 
| 
      
 881 
     | 
    
         
            +
                    Requires `pt.tl.mixscape` to be run first.
         
     | 
| 
      
 882 
     | 
    
         
            +
             
     | 
| 
      
 883 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 884 
     | 
    
         
            +
                        adata: The annotated data object.
         
     | 
| 
      
 885 
     | 
    
         
            +
                        target_gene_idents: Target gene name to plot.
         
     | 
| 
      
 886 
     | 
    
         
            +
                        keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
         
     | 
| 
      
 887 
     | 
    
         
            +
                        groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
         
     | 
| 
      
 888 
     | 
    
         
            +
                        log: Plot on logarithmic axis.
         
     | 
| 
      
 889 
     | 
    
         
            +
                        use_raw: Whether to use `raw` attribute of `adata`.
         
     | 
| 
      
 890 
     | 
    
         
            +
                        stripplot: Add a stripplot on top of the violin plot.
         
     | 
| 
      
 891 
     | 
    
         
            +
                        order: Order in which to show the categories.
         
     | 
| 
      
 892 
     | 
    
         
            +
                        xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
         
     | 
| 
      
 893 
     | 
    
         
            +
                        ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
         
     | 
| 
      
 894 
     | 
    
         
            +
                                If `None` and `groupby` is not `None`, defaults to `keys`.
         
     | 
| 
      
 895 
     | 
    
         
            +
                        show: Show the plot, do not return axis.
         
     | 
| 
      
 896 
     | 
    
         
            +
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
         
     | 
| 
      
 897 
     | 
    
         
            +
                              Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
         
     | 
| 
      
 898 
     | 
    
         
            +
                        ax: A matplotlib axes object. Only works if plotting a single component.
         
     | 
| 
      
 899 
     | 
    
         
            +
                        **kwargs: Additional arguments to `seaborn.violinplot`.
         
     | 
| 
      
 900 
     | 
    
         
            +
             
     | 
| 
      
 901 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 902 
     | 
    
         
            +
                        A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
         
     | 
| 
      
 903 
     | 
    
         
            +
             
     | 
| 
      
 904 
     | 
    
         
            +
                    Examples:
         
     | 
| 
      
 905 
     | 
    
         
            +
                        >>> import pertpy as pt
         
     | 
| 
      
 906 
     | 
    
         
            +
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
      
 907 
     | 
    
         
            +
                        >>> ms_pt = pt.tl.Mixscape()
         
     | 
| 
      
 908 
     | 
    
         
            +
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
         
     | 
| 
      
 909 
     | 
    
         
            +
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
         
     | 
| 
      
 910 
     | 
    
         
            +
                        >>> ms_pt.plot_violin(
         
     | 
| 
      
 911 
     | 
    
         
            +
                        ...     adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
         
     | 
| 
      
 912 
     | 
    
         
            +
                        ... )
         
     | 
| 
      
 913 
     | 
    
         
            +
             
     | 
| 
      
 914 
     | 
    
         
            +
                    Preview:
         
     | 
| 
      
 915 
     | 
    
         
            +
                        .. image:: /_static/docstring_previews/mixscape_violin.png
         
     | 
| 
      
 916 
     | 
    
         
            +
                    """
         
     | 
| 
      
 917 
     | 
    
         
            +
                    if isinstance(target_gene_idents, str):
         
     | 
| 
      
 918 
     | 
    
         
            +
                        mixscape_class_mask = adata.obs[groupby] == target_gene_idents
         
     | 
| 
      
 919 
     | 
    
         
            +
                    elif isinstance(target_gene_idents, list):
         
     | 
| 
      
 920 
     | 
    
         
            +
                        mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
         
     | 
| 
      
 921 
     | 
    
         
            +
                        for ident in target_gene_idents:
         
     | 
| 
      
 922 
     | 
    
         
            +
                            mixscape_class_mask |= adata.obs[groupby] == ident
         
     | 
| 
      
 923 
     | 
    
         
            +
                    adata = adata[mixscape_class_mask]
         
     | 
| 
      
 924 
     | 
    
         
            +
             
     | 
| 
      
 925 
     | 
    
         
            +
                    sanitize_anndata(adata)
         
     | 
| 
      
 926 
     | 
    
         
            +
                    use_raw = _check_use_raw(adata, use_raw)
         
     | 
| 
      
 927 
     | 
    
         
            +
                    if isinstance(keys, str):
         
     | 
| 
      
 928 
     | 
    
         
            +
                        keys = [keys]
         
     | 
| 
      
 929 
     | 
    
         
            +
                    keys = list(OrderedDict.fromkeys(keys))  # remove duplicates, preserving the order
         
     | 
| 
      
 930 
     | 
    
         
            +
             
     | 
| 
      
 931 
     | 
    
         
            +
                    if isinstance(ylabel, str | type(None)):
         
     | 
| 
      
 932 
     | 
    
         
            +
                        ylabel = [ylabel] * (1 if groupby is None else len(keys))
         
     | 
| 
      
 933 
     | 
    
         
            +
                    if groupby is None:
         
     | 
| 
      
 934 
     | 
    
         
            +
                        if len(ylabel) != 1:
         
     | 
| 
      
 935 
     | 
    
         
            +
                            raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
         
     | 
| 
      
 936 
     | 
    
         
            +
                    elif len(ylabel) != len(keys):
         
     | 
| 
      
 937 
     | 
    
         
            +
                        raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
         
     | 
| 
      
 938 
     | 
    
         
            +
             
     | 
| 
      
 939 
     | 
    
         
            +
                    if groupby is not None:
         
     | 
| 
      
 940 
     | 
    
         
            +
                        if hue is not None:
         
     | 
| 
      
 941 
     | 
    
         
            +
                            obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
         
     | 
| 
      
 942 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 943 
     | 
    
         
            +
                            obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
         
     | 
| 
      
 944 
     | 
    
         
            +
             
     | 
| 
      
 945 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 946 
     | 
    
         
            +
                        obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
         
     | 
| 
      
 947 
     | 
    
         
            +
                    if groupby is None:
         
     | 
| 
      
 948 
     | 
    
         
            +
                        obs_tidy = pd.melt(obs_df, value_vars=keys)
         
     | 
| 
      
 949 
     | 
    
         
            +
                        x = "variable"
         
     | 
| 
      
 950 
     | 
    
         
            +
                        ys = ["value"]
         
     | 
| 
      
 951 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 952 
     | 
    
         
            +
                        obs_tidy = obs_df
         
     | 
| 
      
 953 
     | 
    
         
            +
                        x = groupby
         
     | 
| 
      
 954 
     | 
    
         
            +
                        ys = keys
         
     | 
| 
      
 955 
     | 
    
         
            +
             
     | 
| 
      
 956 
     | 
    
         
            +
                    if multi_panel and groupby is None and len(ys) == 1:
         
     | 
| 
      
 957 
     | 
    
         
            +
                        # This is a quick and dirty way for adapting scales across several
         
     | 
| 
      
 958 
     | 
    
         
            +
                        # keys if groupby is None.
         
     | 
| 
      
 959 
     | 
    
         
            +
                        y = ys[0]
         
     | 
| 
      
 960 
     | 
    
         
            +
             
     | 
| 
      
 961 
     | 
    
         
            +
                        g = sns.catplot(
         
     | 
| 
      
 962 
     | 
    
         
            +
                            y=y,
         
     | 
| 
      
 963 
     | 
    
         
            +
                            data=obs_tidy,
         
     | 
| 
      
 964 
     | 
    
         
            +
                            kind="violin",
         
     | 
| 
      
 965 
     | 
    
         
            +
                            scale=scale,
         
     | 
| 
      
 966 
     | 
    
         
            +
                            col=x,
         
     | 
| 
      
 967 
     | 
    
         
            +
                            col_order=keys,
         
     | 
| 
      
 968 
     | 
    
         
            +
                            sharey=False,
         
     | 
| 
      
 969 
     | 
    
         
            +
                            order=keys,
         
     | 
| 
      
 970 
     | 
    
         
            +
                            cut=0,
         
     | 
| 
      
 971 
     | 
    
         
            +
                            inner=None,
         
     | 
| 
      
 972 
     | 
    
         
            +
                            **kwargs,
         
     | 
| 
      
 973 
     | 
    
         
            +
                        )
         
     | 
| 
      
 974 
     | 
    
         
            +
             
     | 
| 
      
 975 
     | 
    
         
            +
                        if stripplot:
         
     | 
| 
      
 976 
     | 
    
         
            +
                            grouped_df = obs_tidy.groupby(x)
         
     | 
| 
      
 977 
     | 
    
         
            +
                            for ax_id, key in zip(range(g.axes.shape[1]), keys, strict=False):
         
     | 
| 
      
 978 
     | 
    
         
            +
                                sns.stripplot(
         
     | 
| 
      
 979 
     | 
    
         
            +
                                    y=y,
         
     | 
| 
      
 980 
     | 
    
         
            +
                                    data=grouped_df.get_group(key),
         
     | 
| 
      
 981 
     | 
    
         
            +
                                    jitter=jitter,
         
     | 
| 
      
 982 
     | 
    
         
            +
                                    size=size,
         
     | 
| 
      
 983 
     | 
    
         
            +
                                    color="black",
         
     | 
| 
      
 984 
     | 
    
         
            +
                                    ax=g.axes[0, ax_id],
         
     | 
| 
      
 985 
     | 
    
         
            +
                                )
         
     | 
| 
      
 986 
     | 
    
         
            +
                        if log:
         
     | 
| 
      
 987 
     | 
    
         
            +
                            g.set(yscale="log")
         
     | 
| 
      
 988 
     | 
    
         
            +
                        g.set_titles(col_template="{col_name}").set_xlabels("")
         
     | 
| 
      
 989 
     | 
    
         
            +
                        if rotation is not None:
         
     | 
| 
      
 990 
     | 
    
         
            +
                            for ax in g.axes[0]:
         
     | 
| 
      
 991 
     | 
    
         
            +
                                ax.tick_params(axis="x", labelrotation=rotation)
         
     | 
| 
      
 992 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 993 
     | 
    
         
            +
                        # set by default the violin plot cut=0 to limit the extent
         
     | 
| 
      
 994 
     | 
    
         
            +
                        # of the violin plot (see stacked_violin code) for more info.
         
     | 
| 
      
 995 
     | 
    
         
            +
                        kwargs.setdefault("cut", 0)
         
     | 
| 
      
 996 
     | 
    
         
            +
                        kwargs.setdefault("inner")
         
     | 
| 
      
 997 
     | 
    
         
            +
             
     | 
| 
      
 998 
     | 
    
         
            +
                        if ax is None:
         
     | 
| 
      
 999 
     | 
    
         
            +
                            axs, _, _, _ = _utils.setup_axes(
         
     | 
| 
      
 1000 
     | 
    
         
            +
                                ax=ax,
         
     | 
| 
      
 1001 
     | 
    
         
            +
                                panels=["x"] if groupby is None else keys,
         
     | 
| 
      
 1002 
     | 
    
         
            +
                                show_ticks=True,
         
     | 
| 
      
 1003 
     | 
    
         
            +
                                right_margin=0.3,
         
     | 
| 
      
 1004 
     | 
    
         
            +
                            )
         
     | 
| 
      
 1005 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 1006 
     | 
    
         
            +
                            axs = [ax]
         
     | 
| 
      
 1007 
     | 
    
         
            +
                        for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
         
     | 
| 
      
 1008 
     | 
    
         
            +
                            ax = sns.violinplot(
         
     | 
| 
      
 1009 
     | 
    
         
            +
                                x=x,
         
     | 
| 
      
 1010 
     | 
    
         
            +
                                y=y,
         
     | 
| 
      
 1011 
     | 
    
         
            +
                                data=obs_tidy,
         
     | 
| 
      
 1012 
     | 
    
         
            +
                                order=order,
         
     | 
| 
      
 1013 
     | 
    
         
            +
                                orient="vertical",
         
     | 
| 
      
 1014 
     | 
    
         
            +
                                scale=scale,
         
     | 
| 
      
 1015 
     | 
    
         
            +
                                ax=ax,
         
     | 
| 
      
 1016 
     | 
    
         
            +
                                hue=hue,
         
     | 
| 
      
 1017 
     | 
    
         
            +
                                **kwargs,
         
     | 
| 
      
 1018 
     | 
    
         
            +
                            )
         
     | 
| 
      
 1019 
     | 
    
         
            +
                            # Get the handles and labels.
         
     | 
| 
      
 1020 
     | 
    
         
            +
                            handles, labels = ax.get_legend_handles_labels()
         
     | 
| 
      
 1021 
     | 
    
         
            +
                            if stripplot:
         
     | 
| 
      
 1022 
     | 
    
         
            +
                                ax = sns.stripplot(
         
     | 
| 
      
 1023 
     | 
    
         
            +
                                    x=x,
         
     | 
| 
      
 1024 
     | 
    
         
            +
                                    y=y,
         
     | 
| 
      
 1025 
     | 
    
         
            +
                                    data=obs_tidy,
         
     | 
| 
      
 1026 
     | 
    
         
            +
                                    order=order,
         
     | 
| 
      
 1027 
     | 
    
         
            +
                                    jitter=jitter,
         
     | 
| 
      
 1028 
     | 
    
         
            +
                                    color="black",
         
     | 
| 
      
 1029 
     | 
    
         
            +
                                    size=size,
         
     | 
| 
      
 1030 
     | 
    
         
            +
                                    ax=ax,
         
     | 
| 
      
 1031 
     | 
    
         
            +
                                    hue=hue,
         
     | 
| 
      
 1032 
     | 
    
         
            +
                                    dodge=True,
         
     | 
| 
      
 1033 
     | 
    
         
            +
                                )
         
     | 
| 
      
 1034 
     | 
    
         
            +
                            if xlabel == "" and groupby is not None and rotation is None:
         
     | 
| 
      
 1035 
     | 
    
         
            +
                                xlabel = groupby.replace("_", " ")
         
     | 
| 
      
 1036 
     | 
    
         
            +
                            ax.set_xlabel(xlabel)
         
     | 
| 
      
 1037 
     | 
    
         
            +
                            if ylab is not None:
         
     | 
| 
      
 1038 
     | 
    
         
            +
                                ax.set_ylabel(ylab)
         
     | 
| 
      
 1039 
     | 
    
         
            +
             
     | 
| 
      
 1040 
     | 
    
         
            +
                            if log:
         
     | 
| 
      
 1041 
     | 
    
         
            +
                                ax.set_yscale("log")
         
     | 
| 
      
 1042 
     | 
    
         
            +
                            if rotation is not None:
         
     | 
| 
      
 1043 
     | 
    
         
            +
                                ax.tick_params(axis="x", labelrotation=rotation)
         
     | 
| 
      
 1044 
     | 
    
         
            +
             
     | 
| 
      
 1045 
     | 
    
         
            +
                    show = settings.autoshow if show is None else show
         
     | 
| 
      
 1046 
     | 
    
         
            +
                    if hue is not None and stripplot is True:
         
     | 
| 
      
 1047 
     | 
    
         
            +
                        plt.legend(handles, labels)
         
     | 
| 
      
 1048 
     | 
    
         
            +
                    _utils.savefig_or_show("mixscape_violin", show=show, save=save)
         
     | 
| 
      
 1049 
     | 
    
         
            +
             
     | 
| 
      
 1050 
     | 
    
         
            +
                    if not show:
         
     | 
| 
      
 1051 
     | 
    
         
            +
                        if multi_panel and groupby is None and len(ys) == 1:
         
     | 
| 
      
 1052 
     | 
    
         
            +
                            return g
         
     | 
| 
      
 1053 
     | 
    
         
            +
                        elif len(axs) == 1:
         
     | 
| 
      
 1054 
     | 
    
         
            +
                            return axs[0]
         
     | 
| 
      
 1055 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 1056 
     | 
    
         
            +
                            return axs
         
     | 
| 
      
 1057 
     | 
    
         
            +
             
     | 
| 
      
 1058 
     | 
    
         
            +
                def plot_lda(  # pragma: no cover
                    self,
                    adata: AnnData,
                    control: str,
                    mixscape_class: str = "mixscape_class",
                    mixscape_class_global: str = "mixscape_class_global",
                    perturbation_type: str | None = "KO",
                    lda_key: str | None = "mixscape_lda",
                    n_components: int | None = None,
                    color_map: Colormap | str | None = None,
                    palette: str | Sequence[str] | None = None,
                    return_fig: bool | None = None,
                    ax: Axes | None = None,
                    show: bool | None = None,
                    save: bool | str | None = None,
                    **kwds,
                    # Quoted on purpose: `Figure` may not be imported in this module.
                ) -> "Figure | Axes | list[Axes] | None":
                    """Visualizing perturbation responses with Linear Discriminant Analysis.

                    Requires the mixscape classification to be present in `.obs` and the LDA
                    result to be present in `.uns` (i.e. `mixscape` and `lda` must be run first).

                    The cells whose `mixscape_class_global` value equals `perturbation_type` or
                    `control` are copied into a subset, the LDA result stored in `.uns[lda_key]`
                    is used as the representation for a neighbor graph, a UMAP is computed on
                    that graph, and the embedding is plotted colored by `mixscape_class`.
                    The input `adata` itself is not modified; all work happens on the copy.

                    Args:
                        adata: The annotated data object.
                        control: Control category from the `pert_key` column.
                        mixscape_class: The column of `.obs` with the mixscape classification result.
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
                        lda_key: Key of `.uns` holding the LDA results. Defaults to `"mixscape_lda"`.
                        n_components: The number of dimensions of the embedding.
                                      If `None`, the number of columns of `.uns[lda_key]` is used.
                        color_map: Forwarded unchanged to `scanpy.pl.umap`.
                        palette: Forwarded unchanged to `scanpy.pl.umap`.
                        return_fig: Forwarded unchanged to `scanpy.pl.umap`.
                        ax: Forwarded unchanged to `scanpy.pl.umap`.
                        show: Show the plot, do not return axis.
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
                              Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
                        **kwds: Additional arguments to `scanpy.pl.umap`.

                    Returns:
                        Whatever `scanpy.pl.umap` returns for the given `return_fig` / `show`
                        combination. NOTE(review): per the scanpy documentation this is the figure
                        when `return_fig` is true, the axes when `show` is `False`, and `None`
                        otherwise - confirm against the pinned scanpy version.

                    Raises:
                        ValueError: If `mixscape_class` is not in `adata.obs` or `lda_key` is not in `adata.uns`.

                    Examples:
                        >>> import pertpy as pt
                        >>> mdata = pt.dt.papalexi_2021()
                        >>> ms_pt = pt.tl.Mixscape()
                        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
                        >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
                        >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
                        >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")

                    Preview:
                        .. image:: /_static/docstring_previews/mixscape_lda.png
                    """
                    # The key is already wrapped in double quotes inside the message, so it is
                    # interpolated without `!r` to avoid output like `.obs["'mixscape_class'"]`.
                    if mixscape_class not in adata.obs:
                        raise ValueError(f'Did not find `.obs["{mixscape_class}"]`. Please run the `mixscape` function first.')
                    if lda_key not in adata.uns:
                        raise ValueError(f'Did not find `.uns["{lda_key}"]`. Please run the `lda` function first.')

                    # Keep only the perturbed cells of the requested type plus the control cells.
                    global_labels = adata.obs[mixscape_class_global]
                    keep_mask = (global_labels == perturbation_type) | (global_labels == control)
                    adata_subset = adata[keep_mask].copy()

                    # Expose the LDA result as an embedding so that it can serve as `use_rep`.
                    lda_result = adata_subset.uns[lda_key]
                    adata_subset.obsm[lda_key] = lda_result
                    if n_components is None:
                        n_components = lda_result.shape[1]

                    sc.pp.neighbors(adata_subset, use_rep=lda_key)
                    sc.tl.umap(adata_subset, n_components=n_components)

                    # Hand the plotting result back to the caller; it used to be discarded,
                    # which made `return_fig` and `show=False` ineffective.
                    return sc.pl.umap(
                        adata_subset,
                        color=mixscape_class,
                        palette=palette,
                        color_map=color_map,
                        return_fig=return_fig,
                        show=show,
                        save=save,
                        ax=ax,
                        **kwds,
                    )
         
     |