pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
 - pertpy/data/__init__.py +66 -1
 - pertpy/data/_dataloader.py +28 -26
 - pertpy/data/_datasets.py +261 -92
 - pertpy/metadata/__init__.py +6 -0
 - pertpy/metadata/_cell_line.py +795 -0
 - pertpy/metadata/_compound.py +128 -0
 - pertpy/metadata/_drug.py +238 -0
 - pertpy/metadata/_look_up.py +569 -0
 - pertpy/metadata/_metadata.py +70 -0
 - pertpy/metadata/_moa.py +125 -0
 - pertpy/plot/__init__.py +0 -13
 - pertpy/preprocessing/__init__.py +2 -0
 - pertpy/preprocessing/_guide_rna.py +89 -6
 - pertpy/tools/__init__.py +48 -15
 - pertpy/tools/_augur.py +329 -32
 - pertpy/tools/_cinemaot.py +145 -6
 - pertpy/tools/_coda/_base_coda.py +1237 -116
 - pertpy/tools/_coda/_sccoda.py +66 -36
 - pertpy/tools/_coda/_tasccoda.py +46 -39
 - pertpy/tools/_dialogue.py +180 -77
 - pertpy/tools/_differential_gene_expression/__init__.py +20 -0
 - pertpy/tools/_differential_gene_expression/_base.py +657 -0
 - pertpy/tools/_differential_gene_expression/_checks.py +41 -0
 - pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
 - pertpy/tools/_differential_gene_expression/_edger.py +125 -0
 - pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
 - pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
 - pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
 - pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
 - pertpy/tools/_distances/_distance_tests.py +29 -24
 - pertpy/tools/_distances/_distances.py +584 -98
 - pertpy/tools/_enrichment.py +460 -0
 - pertpy/tools/_kernel_pca.py +1 -1
 - pertpy/tools/_milo.py +406 -49
 - pertpy/tools/_mixscape.py +677 -55
 - pertpy/tools/_perturbation_space/_clustering.py +10 -3
 - pertpy/tools/_perturbation_space/_comparison.py +112 -0
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
 - pertpy/tools/_perturbation_space/_simple.py +52 -11
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_base_components.py +2 -3
 - pertpy/tools/_scgen/_scgen.py +706 -0
 - pertpy/tools/_scgen/_utils.py +3 -5
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
 - pertpy-0.8.0.dist-info/RECORD +57 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_augur.py +0 -234
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_coda.py +0 -1001
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_guide_rna.py +0 -82
 - pertpy/plot/_milopy.py +0 -284
 - pertpy/plot/_mixscape.py +0 -594
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_differential_gene_expression.py +0 -99
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
 
    
        pertpy/plot/_mixscape.py
    DELETED
    
    | 
         @@ -1,594 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from __future__ import annotations
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            import copy
         
     | 
| 
       4 
     | 
    
         
            -
            from collections import OrderedDict
         
     | 
| 
       5 
     | 
    
         
            -
            from typing import TYPE_CHECKING, Literal
         
     | 
| 
       6 
     | 
    
         
            -
             
     | 
| 
       7 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       8 
     | 
    
         
            -
            import pandas as pd
         
     | 
| 
       9 
     | 
    
         
            -
            import scanpy as sc
         
     | 
| 
       10 
     | 
    
         
            -
            from matplotlib import pyplot as pl
         
     | 
| 
       11 
     | 
    
         
            -
            from plotnine import (
         
     | 
| 
       12 
     | 
    
         
            -
                aes,
         
     | 
| 
       13 
     | 
    
         
            -
                element_blank,
         
     | 
| 
       14 
     | 
    
         
            -
                element_text,
         
     | 
| 
       15 
     | 
    
         
            -
                facet_wrap,
         
     | 
| 
       16 
     | 
    
         
            -
                geom_bar,
         
     | 
| 
       17 
     | 
    
         
            -
                geom_density,
         
     | 
| 
       18 
     | 
    
         
            -
                geom_point,
         
     | 
| 
       19 
     | 
    
         
            -
                ggplot,
         
     | 
| 
       20 
     | 
    
         
            -
                labs,
         
     | 
| 
       21 
     | 
    
         
            -
                scale_color_manual,
         
     | 
| 
       22 
     | 
    
         
            -
                scale_fill_manual,
         
     | 
| 
       23 
     | 
    
         
            -
                theme,
         
     | 
| 
       24 
     | 
    
         
            -
                theme_classic,
         
     | 
| 
       25 
     | 
    
         
            -
                xlab,
         
     | 
| 
       26 
     | 
    
         
            -
                ylab,
         
     | 
| 
       27 
     | 
    
         
            -
            )
         
     | 
| 
       28 
     | 
    
         
            -
            from scanpy import get
         
     | 
| 
       29 
     | 
    
         
            -
            from scanpy._settings import settings
         
     | 
| 
       30 
     | 
    
         
            -
            from scanpy._utils import _check_use_raw, sanitize_anndata
         
     | 
| 
       31 
     | 
    
         
            -
            from scanpy.plotting import _utils
         
     | 
| 
       32 
     | 
    
         
            -
             
     | 
| 
       33 
     | 
    
         
            -
            if TYPE_CHECKING:
         
     | 
| 
       34 
     | 
    
         
            -
                from collections.abc import Sequence
         
     | 
| 
       35 
     | 
    
         
            -
             
     | 
| 
       36 
     | 
    
         
            -
                from anndata import AnnData
         
     | 
| 
       37 
     | 
    
         
            -
                from matplotlib.axes import Axes
         
     | 
| 
       38 
     | 
    
         
            -
             
     | 
| 
       39 
     | 
    
         
            -
             
     | 
| 
       40 
     | 
    
         
            -
            class MixscapePlot:
         
     | 
| 
       41 
     | 
    
         
            -
                """Plotting functions for Mixscape."""
         
     | 
| 
       42 
     | 
    
         
            -
             
     | 
| 
       43 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       44 
     | 
    
         
            -
                def barplot(  # pragma: no cover
         
     | 
| 
       45 
     | 
    
         
            -
                    adata: AnnData,
         
     | 
| 
       46 
     | 
    
         
            -
                    guide_rna_column: str,
         
     | 
| 
       47 
     | 
    
         
            -
                    mixscape_class_global="mixscape_class_global",
         
     | 
| 
       48 
     | 
    
         
            -
                    axis_text_x_size: int = 8,
         
     | 
| 
       49 
     | 
    
         
            -
                    axis_text_y_size: int = 6,
         
     | 
| 
       50 
     | 
    
         
            -
                    axis_title_size: int = 8,
         
     | 
| 
       51 
     | 
    
         
            -
                    strip_text_size: int = 6,
         
     | 
| 
       52 
     | 
    
         
            -
                    panel_spacing_x: float = 0.3,
         
     | 
| 
       53 
     | 
    
         
            -
                    panel_spacing_y: float = 0.3,
         
     | 
| 
       54 
     | 
    
         
            -
                    legend_title_size: int = 8,
         
     | 
| 
       55 
     | 
    
         
            -
                    legend_text_size: int = 8,
         
     | 
| 
       56 
     | 
    
         
            -
                    show: bool | None = None,
         
     | 
| 
       57 
     | 
    
         
            -
                    save: bool | str | None = None,
         
     | 
| 
       58 
     | 
    
         
            -
                ):
         
     | 
| 
       59 
     | 
    
         
            -
                    """Barplot to visualize perturbation scores calculated from RunMixscape function.
         
     | 
| 
       60 
     | 
    
         
            -
             
     | 
| 
       61 
     | 
    
         
            -
                    Args:
         
     | 
| 
       62 
     | 
    
         
            -
                        adata: The annotated data object.
         
     | 
| 
       63 
     | 
    
         
            -
                        guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
         
     | 
| 
       64 
     | 
    
         
            -
                                          The format must be <gene_target>g<#>. For example, 'STAT2g1' and 'ATF2g1'.
         
     | 
| 
       65 
     | 
    
         
            -
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
         
     | 
| 
       66 
     | 
    
         
            -
                        show: Show the plot, do not return axis.
         
     | 
| 
       67 
     | 
    
         
            -
                        save: If True or a str, save the figure. A string is appended to the default filename.
         
     | 
| 
       68 
     | 
    
         
            -
                              Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
         
     | 
| 
       69 
     | 
    
         
            -
             
     | 
| 
       70 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       71 
     | 
    
         
            -
                        If show is False, return ggplot object used to draw the plot.
         
     | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       74 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       75 
     | 
    
         
            -
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       76 
     | 
    
         
            -
                        >>> mixscape_identifier = pt.tl.Mixscape()
         
     | 
| 
       77 
     | 
    
         
            -
                        >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
         
     | 
| 
       78 
     | 
    
         
            -
                        >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
         
     | 
| 
       79 
     | 
    
         
            -
                        >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
         
     | 
| 
       80 
     | 
    
         
            -
                    """
         
     | 
| 
       81 
     | 
    
         
            -
                    if mixscape_class_global not in adata.obs:
         
     | 
| 
       82 
     | 
    
         
            -
                        raise ValueError("Please run `pt.tl.mixscape` first.")
         
     | 
| 
       83 
     | 
    
         
            -
                    count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
         
     | 
| 
       84 
     | 
    
         
            -
                    all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
         
     | 
| 
       85 
     | 
    
         
            -
                    KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
         
     | 
| 
       86 
     | 
    
         
            -
                    KO_cells_percentage = KO_cells_percentage.sort_values("value", ascending=False)
         
     | 
| 
       87 
     | 
    
         
            -
             
     | 
| 
       88 
     | 
    
         
            -
                    new_levels = KO_cells_percentage[guide_rna_column]
         
     | 
| 
       89 
     | 
    
         
            -
                    all_cells_percentage[guide_rna_column] = pd.Categorical(
         
     | 
| 
       90 
     | 
    
         
            -
                        all_cells_percentage[guide_rna_column], categories=new_levels, ordered=False
         
     | 
| 
       91 
     | 
    
         
            -
                    )
         
     | 
| 
       92 
     | 
    
         
            -
                    all_cells_percentage[mixscape_class_global] = pd.Categorical(
         
     | 
| 
       93 
     | 
    
         
            -
                        all_cells_percentage[mixscape_class_global], categories=["NT", "NP", "KO"], ordered=False
         
     | 
| 
       94 
     | 
    
         
            -
                    )
         
     | 
| 
       95 
     | 
    
         
            -
                    all_cells_percentage["gene"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[0]
         
     | 
| 
       96 
     | 
    
         
            -
                    all_cells_percentage["guide_number"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[1]
         
     | 
| 
       97 
     | 
    
         
            -
                    all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
         
     | 
| 
       98 
     | 
    
         
            -
                    NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
         
     | 
| 
       99 
     | 
    
         
            -
             
     | 
| 
       100 
     | 
    
         
            -
                    p1 = (
         
     | 
| 
       101 
     | 
    
         
            -
                        ggplot(NP_KO_cells, aes(x="guide_number", y="value", fill="mixscape_class_global"))
         
     | 
| 
       102 
     | 
    
         
            -
                        + scale_fill_manual(values=["#7d7d7d", "#c9c9c9", "#ff7256"])
         
     | 
| 
       103 
     | 
    
         
            -
                        + geom_bar(stat="identity")
         
     | 
| 
       104 
     | 
    
         
            -
                        + theme_classic()
         
     | 
| 
       105 
     | 
    
         
            -
                        + xlab("sgRNA")
         
     | 
| 
       106 
     | 
    
         
            -
                        + ylab("% of cells")
         
     | 
| 
       107 
     | 
    
         
            -
                    )
         
     | 
| 
       108 
     | 
    
         
            -
             
     | 
| 
       109 
     | 
    
         
            -
                    p1 = (
         
     | 
| 
       110 
     | 
    
         
            -
                        p1
         
     | 
| 
       111 
     | 
    
         
            -
                        + theme(
         
     | 
| 
       112 
     | 
    
         
            -
                            axis_text_x=element_text(size=axis_text_x_size, hjust=2),
         
     | 
| 
       113 
     | 
    
         
            -
                            axis_text_y=element_text(size=axis_text_y_size),
         
     | 
| 
       114 
     | 
    
         
            -
                            axis_title=element_text(size=axis_title_size),
         
     | 
| 
       115 
     | 
    
         
            -
                            strip_text=element_text(size=strip_text_size, face="bold"),
         
     | 
| 
       116 
     | 
    
         
            -
                            panel_spacing_x=panel_spacing_x,
         
     | 
| 
       117 
     | 
    
         
            -
                            panel_spacing_y=panel_spacing_y,
         
     | 
| 
       118 
     | 
    
         
            -
                        )
         
     | 
| 
       119 
     | 
    
         
            -
                        + facet_wrap("gene", ncol=5, scales="free")
         
     | 
| 
       120 
     | 
    
         
            -
                        + labs(fill="mixscape class")
         
     | 
| 
       121 
     | 
    
         
            -
                        + theme(legend_title=element_text(size=legend_title_size), legend_text=element_text(size=legend_text_size))
         
     | 
| 
       122 
     | 
    
         
            -
                    )
         
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
                    _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
         
     | 
| 
       125 
     | 
    
         
            -
                    if not show:
         
     | 
| 
       126 
     | 
    
         
            -
                        return p1
         
     | 
| 
       127 
     | 
    
         
            -
             
     | 
| 
       128 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       129 
     | 
    
         
            -
                def heatmap(  # pragma: no cover
         
     | 
| 
       130 
     | 
    
         
            -
                    adata: AnnData,
         
     | 
| 
       131 
     | 
    
         
            -
                    labels: str,
         
     | 
| 
       132 
     | 
    
         
            -
                    target_gene: str,
         
     | 
| 
       133 
     | 
    
         
            -
                    control: str,
         
     | 
| 
       134 
     | 
    
         
            -
                    layer: str | None = None,
         
     | 
| 
       135 
     | 
    
         
            -
                    method: str | None = "wilcoxon",
         
     | 
| 
       136 
     | 
    
         
            -
                    subsample_number: int | None = 900,
         
     | 
| 
       137 
     | 
    
         
            -
                    vmin: float | None = -2,
         
     | 
| 
       138 
     | 
    
         
            -
                    vmax: float | None = 2,
         
     | 
| 
       139 
     | 
    
         
            -
                    show: bool | None = None,
         
     | 
| 
       140 
     | 
    
         
            -
                    save: bool | str | None = None,
         
     | 
| 
       141 
     | 
    
         
            -
                    **kwds,
         
     | 
| 
       142 
     | 
    
         
            -
                ):
         
     | 
| 
       143 
     | 
    
         
            -
                    """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
         
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
     | 
    
         
            -
                    Args:
         
     | 
| 
       146 
     | 
    
         
            -
                        adata: The annotated data object.
         
     | 
| 
       147 
     | 
    
         
            -
                        labels: The column of `.obs` with target gene labels.
         
     | 
| 
       148 
     | 
    
         
            -
                        target_gene: Target gene name to visualize heatmap for.
         
     | 
| 
       149 
     | 
    
         
            -
                        control: Control category from the `pert_key` column.
         
     | 
| 
       150 
     | 
    
         
            -
                        layer: Key from `adata.layers` whose value will be used to perform tests on.
         
     | 
| 
       151 
     | 
    
         
            -
                        method: The default method is 'wilcoxon', see `method` parameter in `scanpy.tl.rank_genes_groups` for more options.
         
     | 
| 
       152 
     | 
    
         
            -
                        subsample_number: Subsample to this number of observations.
         
     | 
| 
       153 
     | 
    
         
            -
                        vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
         
     | 
| 
       154 
     | 
    
         
            -
                        vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
         
     | 
| 
       155 
     | 
    
         
            -
                        show: Show the plot, do not return axis.
         
     | 
| 
       156 
     | 
    
         
            -
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
         
     | 
| 
       157 
     | 
    
         
            -
                        ax: A matplotlib axes object. Only works if plotting a single component.
         
     | 
| 
       158 
     | 
    
         
            -
                        **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
         
     | 
| 
       159 
     | 
    
         
            -
             
     | 
| 
       160 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       161 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       162 
     | 
    
         
            -
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       163 
     | 
    
         
            -
                        >>> mixscape_identifier = pt.tl.Mixscape()
         
     | 
| 
       164 
     | 
    
         
            -
                        >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
         
     | 
| 
       165 
     | 
    
         
            -
                        >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
         
     | 
| 
       166 
     | 
    
         
            -
                        >>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
         
     | 
| 
       167 
     | 
    
         
            -
                    """
         
     | 
| 
       168 
     | 
    
         
            -
                    if "mixscape_class" not in adata.obs:
         
     | 
| 
       169 
     | 
    
         
            -
                        raise ValueError("Please run `pt.tl.mixscape` first.")
         
     | 
| 
       170 
     | 
    
         
            -
                    adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
         
     | 
| 
       171 
     | 
    
         
            -
                    sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
         
     | 
| 
       172 
     | 
    
         
            -
                    sc.pp.scale(adata_subset, max_value=vmax)
         
     | 
| 
       173 
     | 
    
         
            -
                    sc.pp.subsample(adata_subset, n_obs=subsample_number)
         
     | 
| 
       174 
     | 
    
         
            -
                    return sc.pl.rank_genes_groups_heatmap(
         
     | 
| 
       175 
     | 
    
         
            -
                        adata_subset,
         
     | 
| 
       176 
     | 
    
         
            -
                        groupby="mixscape_class",
         
     | 
| 
       177 
     | 
    
         
            -
                        vmin=vmin,
         
     | 
| 
       178 
     | 
    
         
            -
                        vmax=vmax,
         
     | 
| 
       179 
     | 
    
         
            -
                        n_genes=20,
         
     | 
| 
       180 
     | 
    
         
            -
                        groups=["NT"],
         
     | 
| 
       181 
     | 
    
         
            -
                        show=show,
         
     | 
| 
       182 
     | 
    
         
            -
                        save=save,
         
     | 
| 
       183 
     | 
    
         
            -
                        **kwds,
         
     | 
| 
       184 
     | 
    
         
            -
                    )
         
     | 
| 
       185 
     | 
    
         
            -
             
     | 
| 
       186 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       187 
     | 
    
         
            -
                def perturbscore(  # pragma: no cover
         
     | 
| 
       188 
     | 
    
         
            -
                    adata: AnnData,
         
     | 
| 
       189 
     | 
    
         
            -
                    labels: str,
         
     | 
| 
       190 
     | 
    
         
            -
                    target_gene: str,
         
     | 
| 
       191 
     | 
    
         
            -
                    mixscape_class="mixscape_class",
         
     | 
| 
       192 
     | 
    
         
            -
                    color="orange",
         
     | 
| 
       193 
     | 
    
         
            -
                    split_by: str = None,
         
     | 
| 
       194 
     | 
    
         
            -
                    before_mixscape=False,
         
     | 
| 
       195 
     | 
    
         
            -
                    perturbation_type: str = "KO",
         
     | 
| 
       196 
     | 
    
         
            -
                ):
         
     | 
| 
       197 
     | 
    
         
            -
                    """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first.
         
     | 
| 
       198 
     | 
    
         
            -
             
     | 
| 
       199 
     | 
    
         
            -
                    https://satijalab.org/seurat/reference/plotperturbscore
         
     | 
| 
       200 
     | 
    
         
            -
             
     | 
| 
       201 
     | 
    
         
            -
                    Args:
         
     | 
| 
       202 
     | 
    
         
            -
                        adata: The annotated data object.
         
     | 
| 
       203 
     | 
    
         
            -
                        labels: The column of `.obs` with target gene labels.
         
     | 
| 
       204 
     | 
    
         
            -
                        target_gene: Target gene name to visualize perturbation scores for.
         
     | 
| 
       205 
     | 
    
         
            -
                        mixscape_class: The column of `.obs` with mixscape classifications.
         
     | 
| 
       206 
     | 
    
         
            -
                        color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
         
     | 
| 
       207 
     | 
    
         
            -
                        split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
         
     | 
| 
       208 
     | 
    
         
            -
                            the perturbation signature for every replicate separately.
         
     | 
| 
       209 
     | 
    
         
            -
                        before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID.
         
     | 
| 
       210 
     | 
    
         
            -
                        perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO.
         
     | 
| 
       211 
     | 
    
         
            -
             
     | 
| 
       212 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       213 
     | 
    
         
            -
                        The ggplot object used for drawn.
         
     | 
| 
       214 
     | 
    
         
            -
             
     | 
| 
       215 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       216 
     | 
    
         
            -
                        Visualizing the perturbation scores for the cells in a dataset:
         
     | 
| 
       217 
     | 
    
         
            -
             
     | 
| 
       218 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       219 
     | 
    
         
            -
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       220 
     | 
    
         
            -
                        >>> mixscape_identifier = pt.tl.Mixscape()
         
     | 
| 
       221 
     | 
    
         
            -
                        >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
         
     | 
| 
       222 
     | 
    
         
            -
                        >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
         
     | 
| 
       223 
     | 
    
         
            -
                        >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
         
     | 
| 
       224 
     | 
    
         
            -
                    """
         
     | 
| 
       225 
     | 
    
         
            -
                    if "mixscape" not in adata.uns:
         
     | 
| 
       226 
     | 
    
         
            -
                        raise ValueError("Please run `pt.tl.mixscape` first.")
         
     | 
| 
       227 
     | 
    
         
            -
                    perturbation_score = None
         
     | 
| 
       228 
     | 
    
         
            -
                    for key in adata.uns["mixscape"][target_gene].keys():
         
     | 
| 
       229 
     | 
    
         
            -
                        perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
         
     | 
| 
       230 
     | 
    
         
            -
                        perturbation_score_temp["name"] = key
         
     | 
| 
       231 
     | 
    
         
            -
                        if perturbation_score is None:
         
     | 
| 
       232 
     | 
    
         
            -
                            perturbation_score = copy.deepcopy(perturbation_score_temp)
         
     | 
| 
       233 
     | 
    
         
            -
                        else:
         
     | 
| 
       234 
     | 
    
         
            -
                            perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
         
     | 
| 
       235 
     | 
    
         
            -
                    perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
         
     | 
| 
       236 
     | 
    
         
            -
                    gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
         
     | 
| 
       237 
     | 
    
         
            -
                    # If before_mixscape is True, split densities based on original target gene classification
         
     | 
| 
       238 
     | 
    
         
            -
                    if before_mixscape is True:
         
     | 
| 
       239 
     | 
    
         
            -
                        cols = {gd: "#7d7d7d", target_gene: color}
         
     | 
| 
       240 
     | 
    
         
            -
                        p = ggplot(perturbation_score, aes(x="pvec", color=labels)) + geom_density() + theme_classic()
         
     | 
| 
       241 
     | 
    
         
            -
                        p_copy = copy.deepcopy(p)
         
     | 
| 
       242 
     | 
    
         
            -
                        p_copy._build()
         
     | 
| 
       243 
     | 
    
         
            -
                        top_r = max(p_copy.layers[0].data["density"])
         
     | 
| 
       244 
     | 
    
         
            -
                        perturbation_score["y_jitter"] = perturbation_score["pvec"]
         
     | 
| 
       245 
     | 
    
         
            -
                        rng = np.random.default_rng()
         
     | 
| 
       246 
     | 
    
         
            -
                        perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
         
     | 
| 
       247 
     | 
    
         
            -
                            low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
         
     | 
| 
       248 
     | 
    
         
            -
                        )
         
     | 
| 
       249 
     | 
    
         
            -
                        perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
         
     | 
| 
       250 
     | 
    
         
            -
                            low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
         
     | 
| 
       251 
     | 
    
         
            -
                        )
         
     | 
| 
       252 
     | 
    
         
            -
                        # If split_by is provided, split densities based on the split_by
         
     | 
| 
       253 
     | 
    
         
            -
                        if split_by is not None:
         
     | 
| 
       254 
     | 
    
         
            -
                            perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
         
     | 
| 
       255 
     | 
    
         
            -
                            p2 = (
         
     | 
| 
       256 
     | 
    
         
            -
                                p
         
     | 
| 
       257 
     | 
    
         
            -
                                + scale_color_manual(values=cols, drop=False)
         
     | 
| 
       258 
     | 
    
         
            -
                                + geom_density(size=1.5)
         
     | 
| 
       259 
     | 
    
         
            -
                                + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
         
     | 
| 
       260 
     | 
    
         
            -
                                + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
         
     | 
| 
       261 
     | 
    
         
            -
                                + ylab("Cell density")
         
     | 
| 
       262 
     | 
    
         
            -
                                + xlab("Perturbation score")
         
     | 
| 
       263 
     | 
    
         
            -
                                + theme(
         
     | 
| 
       264 
     | 
    
         
            -
                                    legend_key_size=1,
         
     | 
| 
       265 
     | 
    
         
            -
                                    legend_text=element_text(colour="black", size=14),
         
     | 
| 
       266 
     | 
    
         
            -
                                    legend_title=element_blank(),
         
     | 
| 
       267 
     | 
    
         
            -
                                    plot_title=element_text(size=16, face="bold"),
         
     | 
| 
       268 
     | 
    
         
            -
                                )
         
     | 
| 
       269 
     | 
    
         
            -
                                + facet_wrap("split")
         
     | 
| 
       270 
     | 
    
         
            -
                            )
         
     | 
| 
       271 
     | 
    
         
            -
                        else:
         
     | 
| 
       272 
     | 
    
         
            -
                            p2 = (
         
     | 
| 
       273 
     | 
    
         
            -
                                p
         
     | 
| 
       274 
     | 
    
         
            -
                                + scale_color_manual(values=cols, drop=False)
         
     | 
| 
       275 
     | 
    
         
            -
                                + geom_density(size=1.5)
         
     | 
| 
       276 
     | 
    
         
            -
                                + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
         
     | 
| 
       277 
     | 
    
         
            -
                                + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
         
     | 
| 
       278 
     | 
    
         
            -
                                + ylab("Cell density")
         
     | 
| 
       279 
     | 
    
         
            -
                                + xlab("Perturbation score")
         
     | 
| 
       280 
     | 
    
         
            -
                                + theme(
         
     | 
| 
       281 
     | 
    
         
            -
                                    legend_key_size=1,
         
     | 
| 
       282 
     | 
    
         
            -
                                    legend_text=element_text(colour="black", size=14),
         
     | 
| 
       283 
     | 
    
         
            -
                                    legend_title=element_blank(),
         
     | 
| 
       284 
     | 
    
         
            -
                                    plot_title=element_text(size=16, face="bold"),
         
     | 
| 
       285 
     | 
    
         
            -
                                )
         
     | 
| 
       286 
     | 
    
         
            -
                            )
         
     | 
| 
       287 
     | 
    
         
            -
                    # If before_mixscape is False, split densities based on mixscape classifications
         
     | 
| 
       288 
     | 
    
         
            -
                    else:
         
     | 
| 
       289 
     | 
    
         
            -
                        cols = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
         
     | 
| 
       290 
     | 
    
         
            -
                        p = ggplot(perturbation_score, aes(x="pvec", color="mix")) + geom_density() + theme_classic()
         
     | 
| 
       291 
     | 
    
         
            -
                        p_copy = copy.deepcopy(p)
         
     | 
| 
       292 
     | 
    
         
            -
                        p_copy._build()
         
     | 
| 
       293 
     | 
    
         
            -
                        top_r = max(p_copy.layers[0].data["density"])
         
     | 
| 
       294 
     | 
    
         
            -
                        perturbation_score["y_jitter"] = perturbation_score["pvec"]
         
     | 
| 
       295 
     | 
    
         
            -
                        rng = np.random.default_rng()
         
     | 
| 
       296 
     | 
    
         
            -
                        gd2 = list(
         
     | 
| 
       297 
     | 
    
         
            -
                            set(perturbation_score["mix"]).difference([f"{target_gene} NP", f"{target_gene} {perturbation_type}"])
         
     | 
| 
       298 
     | 
    
         
            -
                        )[0]
         
     | 
| 
       299 
     | 
    
         
            -
                        perturbation_score.loc[perturbation_score["mix"] == gd2, "y_jitter"] = rng.uniform(
         
     | 
| 
       300 
     | 
    
         
            -
                            low=0.001, high=top_r / 10, size=sum(perturbation_score["mix"] == gd2)
         
     | 
| 
       301 
     | 
    
         
            -
                        )
         
     | 
| 
       302 
     | 
    
         
            -
                        perturbation_score.loc[
         
     | 
| 
       303 
     | 
    
         
            -
                            perturbation_score["mix"] == f"{target_gene} {perturbation_type}", "y_jitter"
         
     | 
| 
       304 
     | 
    
         
            -
                        ] = rng.uniform(
         
     | 
| 
       305 
     | 
    
         
            -
                            low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} {perturbation_type}")
         
     | 
| 
       306 
     | 
    
         
            -
                        )
         
     | 
| 
       307 
     | 
    
         
            -
                        perturbation_score.loc[perturbation_score["mix"] == f"{target_gene} NP", "y_jitter"] = rng.uniform(
         
     | 
| 
       308 
     | 
    
         
            -
                            low=-top_r / 10, high=0, size=sum(perturbation_score["mix"] == f"{target_gene} NP")
         
     | 
| 
       309 
     | 
    
         
            -
                        )
         
     | 
| 
       310 
     | 
    
         
            -
                        # If split_by is provided, split densities based on the split_by
         
     | 
| 
       311 
     | 
    
         
            -
                        if split_by is not None:
         
     | 
| 
       312 
     | 
    
         
            -
                            perturbation_score["split"] = adata.obs[split_by][perturbation_score.index]
         
     | 
| 
       313 
     | 
    
         
            -
                            p2 = (
         
     | 
| 
       314 
     | 
    
         
            -
                                ggplot(perturbation_score, aes(x="pvec", color="mix"))
         
     | 
| 
       315 
     | 
    
         
            -
                                + scale_color_manual(values=cols, drop=False)
         
     | 
| 
       316 
     | 
    
         
            -
                                + geom_density(size=1.5)
         
     | 
| 
       317 
     | 
    
         
            -
                                + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
         
     | 
| 
       318 
     | 
    
         
            -
                                + theme_classic()
         
     | 
| 
       319 
     | 
    
         
            -
                                + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
         
     | 
| 
       320 
     | 
    
         
            -
                                + ylab("Cell density")
         
     | 
| 
       321 
     | 
    
         
            -
                                + xlab("Perturbation score")
         
     | 
| 
       322 
     | 
    
         
            -
                                + theme(
         
     | 
| 
       323 
     | 
    
         
            -
                                    legend_key_size=1,
         
     | 
| 
       324 
     | 
    
         
            -
                                    legend_text=element_text(colour="black", size=14),
         
     | 
| 
       325 
     | 
    
         
            -
                                    legend_title=element_blank(),
         
     | 
| 
       326 
     | 
    
         
            -
                                    plot_title=element_text(size=16, face="bold"),
         
     | 
| 
       327 
     | 
    
         
            -
                                )
         
     | 
| 
       328 
     | 
    
         
            -
                                + facet_wrap("split")
         
     | 
| 
       329 
     | 
    
         
            -
                            )
         
     | 
| 
       330 
     | 
    
         
            -
                        else:
         
     | 
| 
       331 
     | 
    
         
            -
                            p2 = (
         
     | 
| 
       332 
     | 
    
         
            -
                                p
         
     | 
| 
       333 
     | 
    
         
            -
                                + scale_color_manual(values=cols, drop=False)
         
     | 
| 
       334 
     | 
    
         
            -
                                + geom_density(size=1.5)
         
     | 
| 
       335 
     | 
    
         
            -
                                + geom_point(aes(x="pvec", y="y_jitter"), size=0.1)
         
     | 
| 
       336 
     | 
    
         
            -
                                + theme(axis_text=element_text(size=18), axis_title=element_text(size=20))
         
     | 
| 
       337 
     | 
    
         
            -
                                + ylab("Cell density")
         
     | 
| 
       338 
     | 
    
         
            -
                                + xlab("Perturbation score")
         
     | 
| 
       339 
     | 
    
         
            -
                                + theme(
         
     | 
| 
       340 
     | 
    
         
            -
                                    legend_key_size=1,
         
     | 
| 
       341 
     | 
    
         
            -
                                    legend_text=element_text(colour="black", size=14),
         
     | 
| 
       342 
     | 
    
         
            -
                                    legend_title=element_blank(),
         
     | 
| 
       343 
     | 
    
         
            -
                                    plot_title=element_text(size=16, face="bold"),
         
     | 
| 
       344 
     | 
    
         
            -
                                )
         
     | 
| 
       345 
     | 
    
         
            -
                            )
         
     | 
| 
       346 
     | 
    
         
            -
                    return p2
         
     | 
| 
       347 
     | 
    
         
            -
             
     | 
| 
       348 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       349 
     | 
    
         
            -
                def violin(  # pragma: no cover
         
     | 
| 
       350 
     | 
    
         
            -
                    adata: AnnData,
         
     | 
| 
       351 
     | 
    
         
            -
                    target_gene_idents: str | list[str],
         
     | 
| 
       352 
     | 
    
         
            -
                    keys: str | Sequence[str] = "mixscape_class_p_ko",
         
     | 
| 
       353 
     | 
    
         
            -
                    groupby: str | None = "mixscape_class",
         
     | 
| 
       354 
     | 
    
         
            -
                    log: bool = False,
         
     | 
| 
       355 
     | 
    
         
            -
                    use_raw: bool | None = None,
         
     | 
| 
       356 
     | 
    
         
            -
                    stripplot: bool = True,
         
     | 
| 
       357 
     | 
    
         
            -
                    hue: str | None = None,
         
     | 
| 
       358 
     | 
    
         
            -
                    jitter: float | bool = True,
         
     | 
| 
       359 
     | 
    
         
            -
                    size: int = 1,
         
     | 
| 
       360 
     | 
    
         
            -
                    layer: str | None = None,
         
     | 
| 
       361 
     | 
    
         
            -
                    scale: Literal["area", "count", "width"] = "width",
         
     | 
| 
       362 
     | 
    
         
            -
                    order: Sequence[str] | None = None,
         
     | 
| 
       363 
     | 
    
         
            -
                    multi_panel: bool | None = None,
         
     | 
| 
       364 
     | 
    
         
            -
                    xlabel: str = "",
         
     | 
| 
       365 
     | 
    
         
            -
                    ylabel: str | Sequence[str] | None = None,
         
     | 
| 
       366 
     | 
    
         
            -
                    rotation: float | None = None,
         
     | 
| 
       367 
     | 
    
         
            -
                    show: bool | None = None,
         
     | 
| 
       368 
     | 
    
         
            -
                    save: bool | str | None = None,
         
     | 
| 
       369 
     | 
    
         
            -
                    ax: Axes | None = None,
         
     | 
| 
       370 
     | 
    
         
            -
                    **kwds,
         
     | 
| 
       371 
     | 
    
         
            -
                ):
         
     | 
| 
       372 
     | 
    
         
            -
                    """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.
         
     | 
| 
       373 
     | 
    
         
            -
             
     | 
| 
       374 
     | 
    
         
            -
                    Args:
         
     | 
| 
       375 
     | 
    
         
            -
                        adata: The annotated data object.
         
     | 
| 
       376 
     | 
    
         
            -
                        target_gene: Target gene name to plot.
         
     | 
| 
       377 
     | 
    
         
            -
                        keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
         
     | 
| 
       378 
     | 
    
         
            -
                        groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
         
     | 
| 
       379 
     | 
    
         
            -
                        log: Plot on logarithmic axis.
         
     | 
| 
       380 
     | 
    
         
            -
                        use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
         
     | 
| 
       381 
     | 
    
         
            -
                        stripplot: Add a stripplot on top of the violin plot.
         
     | 
| 
       382 
     | 
    
         
            -
                        order: Order in which to show the categories.
         
     | 
| 
       383 
     | 
    
         
            -
                        xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
         
     | 
| 
       384 
     | 
    
         
            -
                        ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`. If `None` and `groubpy` is not `None`, defaults to `keys`.
         
     | 
| 
       385 
     | 
    
         
            -
                        show: Show the plot, do not return axis.
         
     | 
| 
       386 
     | 
    
         
            -
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
         
     | 
| 
       387 
     | 
    
         
            -
                        ax: A matplotlib axes object. Only works if plotting a single component.
         
     | 
| 
       388 
     | 
    
         
            -
                        **kwds: Additional arguments to `seaborn.violinplot`.
         
     | 
| 
       389 
     | 
    
         
            -
             
     | 
| 
       390 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       391 
     | 
    
         
            -
                        A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
         
     | 
| 
       392 
     | 
    
         
            -
             
     | 
| 
       393 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       394 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       395 
     | 
    
         
            -
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       396 
     | 
    
         
            -
                        >>> mixscape_identifier = pt.tl.Mixscape()
         
     | 
| 
       397 
     | 
    
         
            -
                        >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
         
     | 
| 
       398 
     | 
    
         
            -
                        >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
         
     | 
| 
       399 
     | 
    
         
            -
                        >>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
         
     | 
| 
       400 
     | 
    
         
            -
                    """
         
     | 
| 
       401 
     | 
    
         
            -
                    if isinstance(target_gene_idents, str):
         
     | 
| 
       402 
     | 
    
         
            -
                        mixscape_class_mask = adata.obs[groupby] == target_gene_idents
         
     | 
| 
       403 
     | 
    
         
            -
                    elif isinstance(target_gene_idents, list):
         
     | 
| 
       404 
     | 
    
         
            -
                        mixscape_class_mask = np.full_like(adata.obs[groupby], False, dtype=bool)
         
     | 
| 
       405 
     | 
    
         
            -
                        for ident in target_gene_idents:
         
     | 
| 
       406 
     | 
    
         
            -
                            mixscape_class_mask |= adata.obs[groupby] == ident
         
     | 
| 
       407 
     | 
    
         
            -
                    adata = adata[mixscape_class_mask]
         
     | 
| 
       408 
     | 
    
         
            -
             
     | 
| 
       409 
     | 
    
         
            -
                    import seaborn as sns  # Slow import, only import if called
         
     | 
| 
       410 
     | 
    
         
            -
             
     | 
| 
       411 
     | 
    
         
            -
                    sanitize_anndata(adata)
         
     | 
| 
       412 
     | 
    
         
            -
                    use_raw = _check_use_raw(adata, use_raw)
         
     | 
| 
       413 
     | 
    
         
            -
                    if isinstance(keys, str):
         
     | 
| 
       414 
     | 
    
         
            -
                        keys = [keys]
         
     | 
| 
       415 
     | 
    
         
            -
                    keys = list(OrderedDict.fromkeys(keys))  # remove duplicates, preserving the order
         
     | 
| 
       416 
     | 
    
         
            -
             
     | 
| 
       417 
     | 
    
         
            -
                    if isinstance(ylabel, (str, type(None))):
         
     | 
| 
       418 
     | 
    
         
            -
                        ylabel = [ylabel] * (1 if groupby is None else len(keys))
         
     | 
| 
       419 
     | 
    
         
            -
                    if groupby is None:
         
     | 
| 
       420 
     | 
    
         
            -
                        if len(ylabel) != 1:
         
     | 
| 
       421 
     | 
    
         
            -
                            raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
         
     | 
| 
       422 
     | 
    
         
            -
                    elif len(ylabel) != len(keys):
         
     | 
| 
       423 
     | 
    
         
            -
                        raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
         
     | 
| 
       424 
     | 
    
         
            -
             
     | 
| 
       425 
     | 
    
         
            -
                    if groupby is not None:
         
     | 
| 
       426 
     | 
    
         
            -
                        if hue is not None:
         
     | 
| 
       427 
     | 
    
         
            -
                            obs_df = get.obs_df(adata, keys=[groupby] + keys + [hue], layer=layer, use_raw=use_raw)
         
     | 
| 
       428 
     | 
    
         
            -
                        else:
         
     | 
| 
       429 
     | 
    
         
            -
                            obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
         
     | 
| 
       430 
     | 
    
         
            -
             
     | 
| 
       431 
     | 
    
         
            -
                    else:
         
     | 
| 
       432 
     | 
    
         
            -
                        obs_df = get.obs_df(adata, keys=keys, layer=layer, use_raw=use_raw)
         
     | 
| 
       433 
     | 
    
         
            -
                    if groupby is None:
         
     | 
| 
       434 
     | 
    
         
            -
                        obs_tidy = pd.melt(obs_df, value_vars=keys)
         
     | 
| 
       435 
     | 
    
         
            -
                        x = "variable"
         
     | 
| 
       436 
     | 
    
         
            -
                        ys = ["value"]
         
     | 
| 
       437 
     | 
    
         
            -
                    else:
         
     | 
| 
       438 
     | 
    
         
            -
                        obs_tidy = obs_df
         
     | 
| 
       439 
     | 
    
         
            -
                        x = groupby
         
     | 
| 
       440 
     | 
    
         
            -
                        ys = keys
         
     | 
| 
       441 
     | 
    
         
            -
             
     | 
| 
       442 
     | 
    
         
            -
                    if multi_panel and groupby is None and len(ys) == 1:
         
     | 
| 
       443 
     | 
    
         
            -
                        # This is a quick and dirty way for adapting scales across several
         
     | 
| 
       444 
     | 
    
         
            -
                        # keys if groupby is None.
         
     | 
| 
       445 
     | 
    
         
            -
                        y = ys[0]
         
     | 
| 
       446 
     | 
    
         
            -
             
     | 
| 
       447 
     | 
    
         
            -
                        g = sns.catplot(
         
     | 
| 
       448 
     | 
    
         
            -
                            y=y,
         
     | 
| 
       449 
     | 
    
         
            -
                            data=obs_tidy,
         
     | 
| 
       450 
     | 
    
         
            -
                            kind="violin",
         
     | 
| 
       451 
     | 
    
         
            -
                            scale=scale,
         
     | 
| 
       452 
     | 
    
         
            -
                            col=x,
         
     | 
| 
       453 
     | 
    
         
            -
                            col_order=keys,
         
     | 
| 
       454 
     | 
    
         
            -
                            sharey=False,
         
     | 
| 
       455 
     | 
    
         
            -
                            order=keys,
         
     | 
| 
       456 
     | 
    
         
            -
                            cut=0,
         
     | 
| 
       457 
     | 
    
         
            -
                            inner=None,
         
     | 
| 
       458 
     | 
    
         
            -
                            **kwds,
         
     | 
| 
       459 
     | 
    
         
            -
                        )
         
     | 
| 
       460 
     | 
    
         
            -
             
     | 
| 
       461 
     | 
    
         
            -
                        if stripplot:
         
     | 
| 
       462 
     | 
    
         
            -
                            grouped_df = obs_tidy.groupby(x)
         
     | 
| 
       463 
     | 
    
         
            -
                            for ax_id, key in zip(range(g.axes.shape[1]), keys):
         
     | 
| 
       464 
     | 
    
         
            -
                                sns.stripplot(
         
     | 
| 
       465 
     | 
    
         
            -
                                    y=y,
         
     | 
| 
       466 
     | 
    
         
            -
                                    data=grouped_df.get_group(key),
         
     | 
| 
       467 
     | 
    
         
            -
                                    jitter=jitter,
         
     | 
| 
       468 
     | 
    
         
            -
                                    size=size,
         
     | 
| 
       469 
     | 
    
         
            -
                                    color="black",
         
     | 
| 
       470 
     | 
    
         
            -
                                    ax=g.axes[0, ax_id],
         
     | 
| 
       471 
     | 
    
         
            -
                                )
         
     | 
| 
       472 
     | 
    
         
            -
                        if log:
         
     | 
| 
       473 
     | 
    
         
            -
                            g.set(yscale="log")
         
     | 
| 
       474 
     | 
    
         
            -
                        g.set_titles(col_template="{col_name}").set_xlabels("")
         
     | 
| 
       475 
     | 
    
         
            -
                        if rotation is not None:
         
     | 
| 
       476 
     | 
    
         
            -
                            for ax in g.axes[0]:
         
     | 
| 
       477 
     | 
    
         
            -
                                ax.tick_params(axis="x", labelrotation=rotation)
         
     | 
| 
       478 
     | 
    
         
            -
                    else:
         
     | 
| 
       479 
     | 
    
         
            -
                        # set by default the violin plot cut=0 to limit the extend
         
     | 
| 
       480 
     | 
    
         
            -
                        # of the violin plot (see stacked_violin code) for more info.
         
     | 
| 
       481 
     | 
    
         
            -
                        kwds.setdefault("cut", 0)
         
     | 
| 
       482 
     | 
    
         
            -
                        kwds.setdefault("inner")
         
     | 
| 
       483 
     | 
    
         
            -
             
     | 
| 
       484 
     | 
    
         
            -
                        if ax is None:
         
     | 
| 
       485 
     | 
    
         
            -
                            axs, _, _, _ = _utils.setup_axes(
         
     | 
| 
       486 
     | 
    
         
            -
                                ax=ax,
         
     | 
| 
       487 
     | 
    
         
            -
                                panels=["x"] if groupby is None else keys,
         
     | 
| 
       488 
     | 
    
         
            -
                                show_ticks=True,
         
     | 
| 
       489 
     | 
    
         
            -
                                right_margin=0.3,
         
     | 
| 
       490 
     | 
    
         
            -
                            )
         
     | 
| 
       491 
     | 
    
         
            -
                        else:
         
     | 
| 
       492 
     | 
    
         
            -
                            axs = [ax]
         
     | 
| 
       493 
     | 
    
         
            -
                        for ax, y, ylab in zip(axs, ys, ylabel):  # noqa: F402
         
     | 
| 
       494 
     | 
    
         
            -
                            ax = sns.violinplot(
         
     | 
| 
       495 
     | 
    
         
            -
                                x=x,
         
     | 
| 
       496 
     | 
    
         
            -
                                y=y,
         
     | 
| 
       497 
     | 
    
         
            -
                                data=obs_tidy,
         
     | 
| 
       498 
     | 
    
         
            -
                                order=order,
         
     | 
| 
       499 
     | 
    
         
            -
                                orient="vertical",
         
     | 
| 
       500 
     | 
    
         
            -
                                scale=scale,
         
     | 
| 
       501 
     | 
    
         
            -
                                ax=ax,
         
     | 
| 
       502 
     | 
    
         
            -
                                hue=hue,
         
     | 
| 
       503 
     | 
    
         
            -
                                **kwds,
         
     | 
| 
       504 
     | 
    
         
            -
                            )
         
     | 
| 
       505 
     | 
    
         
            -
                            # Get the handles and labels.
         
     | 
| 
       506 
     | 
    
         
            -
                            handles, labels = ax.get_legend_handles_labels()
         
     | 
| 
       507 
     | 
    
         
            -
                            if stripplot:
         
     | 
| 
       508 
     | 
    
         
            -
                                ax = sns.stripplot(
         
     | 
| 
       509 
     | 
    
         
            -
                                    x=x,
         
     | 
| 
       510 
     | 
    
         
            -
                                    y=y,
         
     | 
| 
       511 
     | 
    
         
            -
                                    data=obs_tidy,
         
     | 
| 
       512 
     | 
    
         
            -
                                    order=order,
         
     | 
| 
       513 
     | 
    
         
            -
                                    jitter=jitter,
         
     | 
| 
       514 
     | 
    
         
            -
                                    color="black",
         
     | 
| 
       515 
     | 
    
         
            -
                                    size=size,
         
     | 
| 
       516 
     | 
    
         
            -
                                    ax=ax,
         
     | 
| 
       517 
     | 
    
         
            -
                                    hue=hue,
         
     | 
| 
       518 
     | 
    
         
            -
                                    dodge=True,
         
     | 
| 
       519 
     | 
    
         
            -
                                )
         
     | 
| 
       520 
     | 
    
         
            -
                            if xlabel == "" and groupby is not None and rotation is None:
         
     | 
| 
       521 
     | 
    
         
            -
                                xlabel = groupby.replace("_", " ")
         
     | 
| 
       522 
     | 
    
         
            -
                            ax.set_xlabel(xlabel)
         
     | 
| 
       523 
     | 
    
         
            -
                            if ylab is not None:
         
     | 
| 
       524 
     | 
    
         
            -
                                ax.set_ylabel(ylab)
         
     | 
| 
       525 
     | 
    
         
            -
             
     | 
| 
       526 
     | 
    
         
            -
                            if log:
         
     | 
| 
       527 
     | 
    
         
            -
                                ax.set_yscale("log")
         
     | 
| 
       528 
     | 
    
         
            -
                            if rotation is not None:
         
     | 
| 
       529 
     | 
    
         
            -
                                ax.tick_params(axis="x", labelrotation=rotation)
         
     | 
| 
       530 
     | 
    
         
            -
             
     | 
| 
       531 
     | 
    
         
            -
                    show = settings.autoshow if show is None else show
         
     | 
| 
       532 
     | 
    
         
            -
                    if hue is not None and stripplot is True:
         
     | 
| 
       533 
     | 
    
         
            -
                        pl.legend(handles, labels)
         
     | 
| 
       534 
     | 
    
         
            -
                    _utils.savefig_or_show("mixscape_violin", show=show, save=save)
         
     | 
| 
       535 
     | 
    
         
            -
             
     | 
| 
       536 
     | 
    
         
            -
                    if not show:
         
     | 
| 
       537 
     | 
    
         
            -
                        if multi_panel and groupby is None and len(ys) == 1:
         
     | 
| 
       538 
     | 
    
         
            -
                            return g
         
     | 
| 
       539 
     | 
    
         
            -
                        elif len(axs) == 1:
         
     | 
| 
       540 
     | 
    
         
            -
                            return axs[0]
         
     | 
| 
       541 
     | 
    
         
            -
                        else:
         
     | 
| 
       542 
     | 
    
         
            -
                            return axs
         
     | 
| 
       543 
     | 
    
         
            -
             
     | 
| 
       544 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       545 
     | 
    
         
            -
                def lda(  # pragma: no cover
         
     | 
| 
       546 
     | 
    
         
            -
                    adata: AnnData,
         
     | 
| 
       547 
     | 
    
         
            -
                    control: str,
         
     | 
| 
       548 
     | 
    
         
            -
                    mixscape_class="mixscape_class",
         
     | 
| 
       549 
     | 
    
         
            -
                    mixscape_class_global="mixscape_class_global",
         
     | 
| 
       550 
     | 
    
         
            -
                    perturbation_type: str | None = "KO",
         
     | 
| 
       551 
     | 
    
         
            -
                    lda_key: str | None = "mixscape_lda",
         
     | 
| 
       552 
     | 
    
         
            -
                    n_components: int | None = None,
         
     | 
| 
       553 
     | 
    
         
            -
                    show: bool | None = None,
         
     | 
| 
       554 
     | 
    
         
            -
                    save: bool | str | None = None,
         
     | 
| 
       555 
     | 
    
         
            -
                    **kwds,
         
     | 
| 
       556 
     | 
    
         
            -
                ):
         
     | 
| 
       557 
     | 
    
         
            -
                    """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
         
     | 
| 
       558 
     | 
    
         
            -
             
     | 
| 
       559 
     | 
    
         
            -
                    Args:
         
     | 
| 
       560 
     | 
    
         
            -
                        adata: The annotated data object.
         
     | 
| 
       561 
     | 
    
         
            -
                        control: Control category from the `pert_key` column.
         
     | 
| 
       562 
     | 
    
         
            -
                        labels: The column of `.obs` with target gene labels.
         
     | 
| 
       563 
     | 
    
         
            -
                        mixscape_class: The column of `.obs` with the mixscape classification result.
         
     | 
| 
       564 
     | 
    
         
            -
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
         
     | 
| 
       565 
     | 
    
         
            -
                        perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to 'KO'.
         
     | 
| 
       566 
     | 
    
         
            -
                        lda_key: If not speficied, lda looks .uns["mixscape_lda"] for the LDA results.
         
     | 
| 
       567 
     | 
    
         
            -
                        n_components: The number of dimensions of the embedding.
         
     | 
| 
       568 
     | 
    
         
            -
                        show: Show the plot, do not return axis.
         
     | 
| 
       569 
     | 
    
         
            -
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
         
     | 
| 
       570 
     | 
    
         
            -
                        **kwds: Additional arguments to `scanpy.pl.umap`.
         
     | 
| 
       571 
     | 
    
         
            -
             
     | 
| 
       572 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       573 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       574 
     | 
    
         
            -
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
       575 
     | 
    
         
            -
                        >>> mixscape_identifier = pt.tl.Mixscape()
         
     | 
| 
       576 
     | 
    
         
            -
                        >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
         
     | 
| 
       577 
     | 
    
         
            -
                        >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
         
     | 
| 
       578 
     | 
    
         
            -
                        >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
         
     | 
| 
       579 
     | 
    
         
            -
                        >>> pt.pl.ms.lda(adata=mdata['rna'], control='NT')
         
     | 
| 
       580 
     | 
    
         
            -
                    """
         
     | 
| 
       581 
     | 
    
         
            -
                    if mixscape_class not in adata.obs:
         
     | 
| 
       582 
     | 
    
         
            -
                        raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
         
     | 
| 
       583 
     | 
    
         
            -
                    if lda_key not in adata.uns:
         
     | 
| 
       584 
     | 
    
         
            -
                        raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.')
         
     | 
| 
       585 
     | 
    
         
            -
             
     | 
| 
       586 
     | 
    
         
            -
                    adata_subset = adata[
         
     | 
| 
       587 
     | 
    
         
            -
                        (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
         
     | 
| 
       588 
     | 
    
         
            -
                    ].copy()
         
     | 
| 
       589 
     | 
    
         
            -
                    adata_subset.obsm[lda_key] = adata_subset.uns[lda_key]
         
     | 
| 
       590 
     | 
    
         
            -
                    if n_components is None:
         
     | 
| 
       591 
     | 
    
         
            -
                        n_components = adata_subset.uns[lda_key].shape[1]
         
     | 
| 
       592 
     | 
    
         
            -
                    sc.pp.neighbors(adata_subset, use_rep=lda_key)
         
     | 
| 
       593 
     | 
    
         
            -
                    sc.tl.umap(adata_subset, n_components=n_components)
         
     | 
| 
       594 
     | 
    
         
            -
                    sc.pl.umap(adata_subset, color=mixscape_class, show=show, save=save, **kwds)
         
     |