pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
 - pertpy/data/__init__.py +66 -1
 - pertpy/data/_dataloader.py +28 -26
 - pertpy/data/_datasets.py +261 -92
 - pertpy/metadata/__init__.py +6 -0
 - pertpy/metadata/_cell_line.py +795 -0
 - pertpy/metadata/_compound.py +128 -0
 - pertpy/metadata/_drug.py +238 -0
 - pertpy/metadata/_look_up.py +569 -0
 - pertpy/metadata/_metadata.py +70 -0
 - pertpy/metadata/_moa.py +125 -0
 - pertpy/plot/__init__.py +0 -13
 - pertpy/preprocessing/__init__.py +2 -0
 - pertpy/preprocessing/_guide_rna.py +89 -6
 - pertpy/tools/__init__.py +48 -15
 - pertpy/tools/_augur.py +329 -32
 - pertpy/tools/_cinemaot.py +145 -6
 - pertpy/tools/_coda/_base_coda.py +1237 -116
 - pertpy/tools/_coda/_sccoda.py +66 -36
 - pertpy/tools/_coda/_tasccoda.py +46 -39
 - pertpy/tools/_dialogue.py +180 -77
 - pertpy/tools/_differential_gene_expression/__init__.py +20 -0
 - pertpy/tools/_differential_gene_expression/_base.py +657 -0
 - pertpy/tools/_differential_gene_expression/_checks.py +41 -0
 - pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
 - pertpy/tools/_differential_gene_expression/_edger.py +125 -0
 - pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
 - pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
 - pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
 - pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
 - pertpy/tools/_distances/_distance_tests.py +29 -24
 - pertpy/tools/_distances/_distances.py +584 -98
 - pertpy/tools/_enrichment.py +460 -0
 - pertpy/tools/_kernel_pca.py +1 -1
 - pertpy/tools/_milo.py +406 -49
 - pertpy/tools/_mixscape.py +677 -55
 - pertpy/tools/_perturbation_space/_clustering.py +10 -3
 - pertpy/tools/_perturbation_space/_comparison.py +112 -0
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
 - pertpy/tools/_perturbation_space/_simple.py +52 -11
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_base_components.py +2 -3
 - pertpy/tools/_scgen/_scgen.py +706 -0
 - pertpy/tools/_scgen/_utils.py +3 -5
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
 - pertpy-0.8.0.dist-info/RECORD +57 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_augur.py +0 -234
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_coda.py +0 -1001
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_guide_rna.py +0 -82
 - pertpy/plot/_milopy.py +0 -284
 - pertpy/plot/_mixscape.py +0 -594
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_differential_gene_expression.py +0 -99
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
 
| 
         @@ -0,0 +1,657 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import os
         
     | 
| 
      
 2 
     | 
    
         
            +
            from abc import ABC, abstractmethod
         
     | 
| 
      
 3 
     | 
    
         
            +
            from dataclasses import dataclass
         
     | 
| 
      
 4 
     | 
    
         
            +
            from itertools import chain
         
     | 
| 
      
 5 
     | 
    
         
            +
            from types import MappingProxyType
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import adjustText
         
     | 
| 
      
 8 
     | 
    
         
            +
            import anndata as ad
         
     | 
| 
      
 9 
     | 
    
         
            +
            import matplotlib.patheffects as PathEffects
         
     | 
| 
      
 10 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
      
 11 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 12 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 13 
     | 
    
         
            +
            import seaborn as sns
         
     | 
| 
      
 14 
     | 
    
         
            +
            from matplotlib.ticker import MaxNLocator
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
         
     | 
| 
      
 17 
     | 
    
         
            +
            from pertpy.tools._differential_gene_expression._formulaic import (
         
     | 
| 
      
 18 
     | 
    
         
            +
                AmbiguousAttributeError,
         
     | 
| 
      
 19 
     | 
    
         
            +
                Factor,
         
     | 
| 
      
 20 
     | 
    
         
            +
                get_factor_storage_and_materializer,
         
     | 
| 
      
 21 
     | 
    
         
            +
                resolve_ambiguous,
         
     | 
| 
      
 22 
     | 
    
         
            +
            )
         
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
            @dataclass
         
     | 
| 
      
 26 
     | 
    
         
            +
            class Contrast:
         
     | 
| 
      
 27 
     | 
    
         
            +
                """Simple contrast for comparison between groups"""
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
                column: str
         
     | 
| 
      
 30 
     | 
    
         
            +
                baseline: str
         
     | 
| 
      
 31 
     | 
    
         
            +
                group_to_compare: str
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
            ContrastType = Contrast | tuple[str, str, str]
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
             
     | 
| 
      
 37 
     | 
    
         
            +
            class MethodBase(ABC):
         
     | 
| 
      
 38 
     | 
    
         
            +
                def __init__(self, adata, *, mask=None, layer=None, **kwargs):
                    """Store the input AnnData (optionally feature-subset) and validate its data matrix.

                    Args:
                        adata: AnnData object, usually pseudobulked.
                        mask: Name of a boolean column in `adata.var`; when given, only the features
                            flagged True are kept.
                        layer: Layer to use in fit(). If None, the X array is used.
                        **kwargs: Keyword arguments specific to the method implementation. They are
                            accepted but not used by this base initializer.
                    """
                    self.adata = adata
                    if mask is not None:
                        # Subset features (columns) with the boolean `.var` column before validation,
                        # so the numeric check only sees the selected features.
                        self.adata = adata[:, adata.var[mask]]

                    self.layer = layer
                    # `self.data` resolves to X or the requested layer; reject non-numeric input early.
                    check_is_numeric_matrix(self.data)
         
     | 
| 
      
 54 
     | 
    
         
            +
             
     | 
| 
      
 55 
     | 
    
         
            +
                @property
         
     | 
| 
      
 56 
     | 
    
         
            +
                def data(self):
         
     | 
| 
      
 57 
     | 
    
         
            +
                    """Get the data matrix from anndata this object was initalized with (X or layer)."""
         
     | 
| 
      
 58 
     | 
    
         
            +
                    if self.layer is None:
         
     | 
| 
      
 59 
     | 
    
         
            +
                        return self.adata.X
         
     | 
| 
      
 60 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 61 
     | 
    
         
            +
                        return self.adata.layer[self.layer]
         
     | 
| 
      
 62 
     | 
    
         
            +
             
     | 
| 
      
 63 
     | 
    
         
            +
                @classmethod
         
     | 
| 
      
 64 
     | 
    
         
            +
                @abstractmethod
         
     | 
| 
      
 65 
     | 
    
         
            +
                def compare_groups(
         
     | 
| 
      
 66 
     | 
    
         
            +
                    cls,
         
     | 
| 
      
 67 
     | 
    
         
            +
                    adata,
         
     | 
| 
      
 68 
     | 
    
         
            +
                    column,
         
     | 
| 
      
 69 
     | 
    
         
            +
                    baseline,
         
     | 
| 
      
 70 
     | 
    
         
            +
                    groups_to_compare,
         
     | 
| 
      
 71 
     | 
    
         
            +
                    *,
         
     | 
| 
      
 72 
     | 
    
         
            +
                    paired_by=None,
         
     | 
| 
      
 73 
     | 
    
         
            +
                    mask=None,
         
     | 
| 
      
 74 
     | 
    
         
            +
                    layer=None,
         
     | 
| 
      
 75 
     | 
    
         
            +
                    fit_kwargs=MappingProxyType({}),
         
     | 
| 
      
 76 
     | 
    
         
            +
                    test_kwargs=MappingProxyType({}),
         
     | 
| 
      
 77 
     | 
    
         
            +
                ):
         
     | 
| 
      
 78 
     | 
    
         
            +
                    """
         
     | 
| 
      
 79 
     | 
    
         
            +
                    Compare between groups in a specified column.
         
     | 
| 
      
 80 
     | 
    
         
            +
             
     | 
| 
      
 81 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 82 
     | 
    
         
            +
                        adata: AnnData object.
         
     | 
| 
      
 83 
     | 
    
         
            +
                        column: column in obs that contains the grouping information.
         
     | 
| 
      
 84 
     | 
    
         
            +
                        baseline: baseline value (one category from variable).
         
     | 
| 
      
 85 
     | 
    
         
            +
                        groups_to_compare: One or multiple categories from variable to compare against baseline.
         
     | 
| 
      
 86 
     | 
    
         
            +
                        paired_by: Column from `obs` that contains information about paired sample (e.g. subject_id).
         
     | 
| 
      
 87 
     | 
    
         
            +
                        mask: Subset anndata by a boolean mask stored in this column in `.obs` before making any tests.
         
     | 
| 
      
 88 
     | 
    
         
            +
                        layer: Use this layer instead of `.X`.
         
     | 
| 
      
 89 
     | 
    
         
            +
                        fit_kwargs: Additional fit options.
         
     | 
| 
      
 90 
     | 
    
         
            +
                        test_kwargs: Additional test options.
         
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
| 
      
 92 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 93 
     | 
    
         
            +
                        Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
         
     | 
| 
      
 94 
     | 
    
         
            +
                    """
         
     | 
| 
      
 95 
     | 
    
         
            +
                    ...
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
                def plot_volcano(
         
     | 
| 
      
 98 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 99 
     | 
    
         
            +
                    data: pd.DataFrame | ad.AnnData,
         
     | 
| 
      
 100 
     | 
    
         
            +
                    *,
         
     | 
| 
      
 101 
     | 
    
         
            +
                    log2fc_col: str = "log_fc",
         
     | 
| 
      
 102 
     | 
    
         
            +
                    pvalue_col: str = "adj_p_value",
         
     | 
| 
      
 103 
     | 
    
         
            +
                    symbol_col: str = "variable",
         
     | 
| 
      
 104 
     | 
    
         
            +
                    pval_thresh: float = 0.05,
         
     | 
| 
      
 105 
     | 
    
         
            +
                    log2fc_thresh: float = 0.75,
         
     | 
| 
      
 106 
     | 
    
         
            +
                    to_label: int | list[str] = 5,
         
     | 
| 
      
 107 
     | 
    
         
            +
                    s_curve: bool | None = False,
         
     | 
| 
      
 108 
     | 
    
         
            +
                    colors: list[str] = None,
         
     | 
| 
      
 109 
     | 
    
         
            +
                    varm_key: str | None = None,
         
     | 
| 
      
 110 
     | 
    
         
            +
                    color_dict: dict[str, list[str]] | None = None,
         
     | 
| 
      
 111 
     | 
    
         
            +
                    shape_dict: dict[str, list[str]] | None = None,
         
     | 
| 
      
 112 
     | 
    
         
            +
                    size_col: str | None = None,
         
     | 
| 
      
 113 
     | 
    
         
            +
                    fontsize: int = 10,
         
     | 
| 
      
 114 
     | 
    
         
            +
                    top_right_frame: bool = False,
         
     | 
| 
      
 115 
     | 
    
         
            +
                    figsize: tuple[int, int] = (5, 5),
         
     | 
| 
      
 116 
     | 
    
         
            +
                    legend_pos: tuple[float, float] = (1.6, 1),
         
     | 
| 
      
 117 
     | 
    
         
            +
                    point_sizes: tuple[int, int] = (15, 150),
         
     | 
| 
      
 118 
     | 
    
         
            +
                    save: bool | str | None = None,
         
     | 
| 
      
 119 
     | 
    
         
            +
                    shapes: list[str] | None = None,
         
     | 
| 
      
 120 
     | 
    
         
            +
                    shape_order: list[str] | None = None,
         
     | 
| 
      
 121 
     | 
    
         
            +
                    x_label: str | None = None,
         
     | 
| 
      
 122 
     | 
    
         
            +
                    y_label: str | None = None,
         
     | 
| 
      
 123 
     | 
    
         
            +
                    **kwargs: int,
         
     | 
| 
      
 124 
     | 
    
         
            +
                ) -> None:
         
     | 
| 
      
 125 
     | 
    
         
            +
                    """Creates a volcano plot from a pandas DataFrame or Anndata.
         
     | 
| 
      
 126 
     | 
    
         
            +
             
     | 
| 
      
 127 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 128 
     | 
    
         
            +
                        data: DataFrame or Anndata to plot.
         
     | 
| 
      
 129 
     | 
    
         
            +
                        log2fc_col: Column name of log2 Fold-Change values.
         
     | 
| 
      
 130 
     | 
    
         
            +
                        pvalue_col: Column name of the p values.
         
     | 
| 
      
 131 
     | 
    
         
            +
                        symbol_col: Column name of gene IDs.
         
     | 
| 
      
 132 
     | 
    
         
            +
                        varm_key: Key in Anndata.varm slot to use for plotting if an Anndata object was passed.
         
     | 
| 
      
 133 
     | 
    
         
            +
                        size_col: Column name to size points by.
         
     | 
| 
      
 134 
     | 
    
         
            +
                        point_sizes: Lower and upper bounds of point sizes.
         
     | 
| 
      
 135 
     | 
    
         
            +
                        pval_thresh: Threshold p value for significance.
         
     | 
| 
      
 136 
     | 
    
         
            +
                        log2fc_thresh: Threshold for log2 fold change significance.
         
     | 
| 
      
 137 
     | 
    
         
            +
                        to_label: Number of top genes or list of genes to label.
         
     | 
| 
      
 138 
     | 
    
         
            +
                        s_curve: Whether to use a reciprocal threshold for up and down gene determination.
         
     | 
| 
      
 139 
     | 
    
         
            +
                        color_dict: Dictionary for coloring dots by categories.
         
     | 
| 
      
 140 
     | 
    
         
            +
                        shape_dict: Dictionary for shaping dots by categories.
         
     | 
| 
      
 141 
     | 
    
         
            +
                        fontsize: Size of gene labels.
         
     | 
| 
      
 142 
     | 
    
         
            +
                        colors: Colors for [non-DE, up, down] genes. Defaults to ['gray', '#D62728', '#1F77B4'].
         
     | 
| 
      
 143 
     | 
    
         
            +
                        top_right_frame: Whether to show the top and right frame of the plot.
         
     | 
| 
      
 144 
     | 
    
         
            +
                        figsize: Size of the figure.
         
     | 
| 
      
 145 
     | 
    
         
            +
                        legend_pos: Position of the legend as determined by matplotlib.
         
     | 
| 
      
 146 
     | 
    
         
            +
                        save: Saves the plot if True or to the path provided.
         
     | 
| 
      
 147 
     | 
    
         
            +
                        shapes: List of matplotlib marker ids.
         
     | 
| 
      
 148 
     | 
    
         
            +
                        shape_order: Order of categories for shapes.
         
     | 
| 
      
 149 
     | 
    
         
            +
                        x_label: Label for the x-axis.
         
     | 
| 
      
 150 
     | 
    
         
            +
                        y_label: Label for the y-axis.
         
     | 
| 
      
 151 
     | 
    
         
            +
                        **kwargs: Additional arguments for seaborn.scatterplot.
         
     | 
| 
      
 152 
     | 
    
         
            +
                    """
         
     | 
| 
      
 153 
     | 
    
         
            +
                    if colors is None:
         
     | 
| 
      
 154 
     | 
    
         
            +
                        colors = ["gray", "#D62728", "#1F77B4"]
         
     | 
| 
      
 155 
     | 
    
         
            +
             
     | 
| 
      
 156 
     | 
    
         
            +
                    def _pval_reciprocal(lfc: float) -> float:
         
     | 
| 
      
 157 
     | 
    
         
            +
                        """
         
     | 
| 
      
 158 
     | 
    
         
            +
                        Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
         
     | 
| 
      
 159 
     | 
    
         
            +
             
     | 
| 
      
 160 
     | 
    
         
            +
                        Used for plotting the S-curve
         
     | 
| 
      
 161 
     | 
    
         
            +
                        """
         
     | 
| 
      
 162 
     | 
    
         
            +
                        return pval_thresh / (lfc - log2fc_thresh)
         
     | 
| 
      
 163 
     | 
    
         
            +
             
     | 
| 
      
 164 
     | 
    
         
            +
                    def _map_shape(symbol: str) -> str:
         
     | 
| 
      
 165 
     | 
    
         
            +
                        if shape_dict is not None:
         
     | 
| 
      
 166 
     | 
    
         
            +
                            for k in shape_dict.keys():
         
     | 
| 
      
 167 
     | 
    
         
            +
                                if shape_dict[k] is not None and symbol in shape_dict[k]:
         
     | 
| 
      
 168 
     | 
    
         
            +
                                    return k
         
     | 
| 
      
 169 
     | 
    
         
            +
                        return "other"
         
     | 
| 
      
 170 
     | 
    
         
            +
             
     | 
| 
      
 171 
     | 
    
         
            +
                    # TODO join the two mapping functions
         
     | 
| 
      
 172 
     | 
    
         
            +
                    def _map_genes_categories(
         
     | 
| 
      
 173 
     | 
    
         
            +
                        row: pd.Series,
         
     | 
| 
      
 174 
     | 
    
         
            +
                        log2fc_col: str,
         
     | 
| 
      
 175 
     | 
    
         
            +
                        nlog10_col: str,
         
     | 
| 
      
 176 
     | 
    
         
            +
                        log2fc_thresh: float,
         
     | 
| 
      
 177 
     | 
    
         
            +
                        pval_thresh: float = None,
         
     | 
| 
      
 178 
     | 
    
         
            +
                        s_curve: bool = False,
         
     | 
| 
      
 179 
     | 
    
         
            +
                    ) -> str:
         
     | 
| 
      
 180 
     | 
    
         
            +
                        """
         
     | 
| 
      
 181 
     | 
    
         
            +
                        Map genes to categorize based on log2fc and pvalue.
         
     | 
| 
      
 182 
     | 
    
         
            +
             
     | 
| 
      
 183 
     | 
    
         
            +
                        These categories are used for coloring the dots.
         
     | 
| 
      
 184 
     | 
    
         
            +
                        Used when no color_dict is passed, sets up/down/nonsignificant.
         
     | 
| 
      
 185 
     | 
    
         
            +
                        """
         
     | 
| 
      
 186 
     | 
    
         
            +
                        log2fc = row[log2fc_col]
         
     | 
| 
      
 187 
     | 
    
         
            +
                        nlog10 = row[nlog10_col]
         
     | 
| 
      
 188 
     | 
    
         
            +
             
     | 
| 
      
 189 
     | 
    
         
            +
                        if s_curve:
         
     | 
| 
      
 190 
     | 
    
         
            +
                            # S-curve condition for Up or Down categorization
         
     | 
| 
      
 191 
     | 
    
         
            +
                            reciprocal_thresh = _pval_reciprocal(abs(log2fc))
         
     | 
| 
      
 192 
     | 
    
         
            +
                            if log2fc > log2fc_thresh and nlog10 > reciprocal_thresh:
         
     | 
| 
      
 193 
     | 
    
         
            +
                                return "Up"
         
     | 
| 
      
 194 
     | 
    
         
            +
                            elif log2fc < -log2fc_thresh and nlog10 > reciprocal_thresh:
         
     | 
| 
      
 195 
     | 
    
         
            +
                                return "Down"
         
     | 
| 
      
 196 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 197 
     | 
    
         
            +
                                return "not DE"
         
     | 
| 
      
 198 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 199 
     | 
    
         
            +
                            # Standard condition for Up or Down categorization
         
     | 
| 
      
 200 
     | 
    
         
            +
                            if log2fc > log2fc_thresh and nlog10 > pval_thresh:
         
     | 
| 
      
 201 
     | 
    
         
            +
                                return "Up"
         
     | 
| 
      
 202 
     | 
    
         
            +
                            elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
         
     | 
| 
      
 203 
     | 
    
         
            +
                                return "Down"
         
     | 
| 
      
 204 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 205 
     | 
    
         
            +
                                return "not DE"
         
     | 
| 
      
 206 
     | 
    
         
            +
             
     | 
| 
      
 207 
     | 
    
         
            +
                    def _map_genes_categories_highlight(
         
     | 
| 
      
 208 
     | 
    
         
            +
                        row: pd.Series,
         
     | 
| 
      
 209 
     | 
    
         
            +
                        log2fc_col: str,
         
     | 
| 
      
 210 
     | 
    
         
            +
                        nlog10_col: str,
         
     | 
| 
      
 211 
     | 
    
         
            +
                        log2fc_thresh: float,
         
     | 
| 
      
 212 
     | 
    
         
            +
                        pval_thresh: float = None,
         
     | 
| 
      
 213 
     | 
    
         
            +
                        s_curve: bool = False,
         
     | 
| 
      
 214 
     | 
    
         
            +
                        symbol_col: str = None,
         
     | 
| 
      
 215 
     | 
    
         
            +
                    ) -> str:
         
     | 
| 
      
 216 
     | 
    
         
            +
                        """
         
     | 
| 
      
 217 
     | 
    
         
            +
                        Map genes to categorize based on log2fc and pvalue.
         
     | 
| 
      
 218 
     | 
    
         
            +
             
     | 
| 
      
 219 
     | 
    
         
            +
                        These categories are used for coloring the dots.
         
     | 
| 
      
 220 
     | 
    
         
            +
                        Used when color_dict is passed, sets DE / not DE for background and user supplied highlight genes.
         
     | 
| 
      
 221 
     | 
    
         
            +
                        """
         
     | 
| 
      
 222 
     | 
    
         
            +
                        log2fc = row[log2fc_col]
         
     | 
| 
      
 223 
     | 
    
         
            +
                        nlog10 = row[nlog10_col]
         
     | 
| 
      
 224 
     | 
    
         
            +
                        symbol = row[symbol_col]
         
     | 
| 
      
 225 
     | 
    
         
            +
             
     | 
| 
      
 226 
     | 
    
         
            +
                        if color_dict is not None:
         
     | 
| 
      
 227 
     | 
    
         
            +
                            for k in color_dict.keys():
         
     | 
| 
      
 228 
     | 
    
         
            +
                                if symbol in color_dict[k]:
         
     | 
| 
      
 229 
     | 
    
         
            +
                                    return k
         
     | 
| 
      
 230 
     | 
    
         
            +
             
     | 
| 
      
 231 
     | 
    
         
            +
                        if s_curve:
         
     | 
| 
      
 232 
     | 
    
         
            +
                            # Use S-curve condition for filtering DE
         
     | 
| 
      
 233 
     | 
    
         
            +
                            if nlog10 > _pval_reciprocal(abs(log2fc)) and abs(log2fc) > log2fc_thresh:
         
     | 
| 
      
 234 
     | 
    
         
            +
                                return "DE"
         
     | 
| 
      
 235 
     | 
    
         
            +
                            return "not DE"
         
     | 
| 
      
 236 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 237 
     | 
    
         
            +
                            # Use standard condition for filtering DE
         
     | 
| 
      
 238 
     | 
    
         
            +
                            if abs(log2fc) < log2fc_thresh or nlog10 < pval_thresh:
         
     | 
| 
      
 239 
     | 
    
         
            +
                                return "not DE"
         
     | 
| 
      
 240 
     | 
    
         
            +
                            return "DE"
         
     | 
| 
      
 241 
     | 
    
         
            +
             
     | 
| 
      
 242 
     | 
    
         
            +
                    if isinstance(data, ad.AnnData):
         
     | 
| 
      
 243 
     | 
    
         
            +
                        if varm_key is None:
         
     | 
| 
      
 244 
     | 
    
         
            +
                            raise ValueError("Please pass a .varm key to use for plotting")
         
     | 
| 
      
 245 
     | 
    
         
            +
             
     | 
| 
      
 246 
     | 
    
         
            +
                        raise NotImplementedError("Anndata not implemented yet")
         
     | 
| 
      
 247 
     | 
    
         
            +
                        df = data.varm[varm_key].copy()
         
     | 
| 
      
 248 
     | 
    
         
            +
             
     | 
| 
      
 249 
     | 
    
         
            +
                    df = data.copy(deep=True)
         
     | 
| 
      
 250 
     | 
    
         
            +
             
     | 
| 
      
 251 
     | 
    
         
            +
                    # clean and replace 0s as they would lead to -inf
         
     | 
| 
      
 252 
     | 
    
         
            +
                    if df[[log2fc_col, pvalue_col]].isnull().values.any():
         
     | 
| 
      
 253 
     | 
    
         
            +
                        print("NaNs encountered, dropping rows with NaNs")
         
     | 
| 
      
 254 
     | 
    
         
            +
                        df = df.dropna(subset=[log2fc_col, pvalue_col])
         
     | 
| 
      
 255 
     | 
    
         
            +
             
     | 
| 
      
 256 
     | 
    
         
            +
                    if df[pvalue_col].min() == 0:
         
     | 
| 
      
 257 
     | 
    
         
            +
                        print("0s encountered for p value, replacing with 1e-323")
         
     | 
| 
      
 258 
     | 
    
         
            +
                        df.loc[df[pvalue_col] == 0, pvalue_col] = 1e-323
         
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
                    # convert p value threshold to nlog10
         
     | 
| 
      
 261 
     | 
    
         
            +
                    pval_thresh = -np.log10(pval_thresh)
         
     | 
| 
      
 262 
     | 
    
         
            +
                    # make nlog10 column
         
     | 
| 
      
 263 
     | 
    
         
            +
                    df["nlog10"] = -np.log10(df[pvalue_col])
         
     | 
| 
      
 264 
     | 
    
         
            +
                    y_max = df["nlog10"].max() + 1
         
     | 
| 
      
 265 
     | 
    
         
            +
                    # make a column to pick top genes
         
     | 
| 
      
 266 
     | 
    
         
            +
                    df["top_genes"] = df["nlog10"] * df[log2fc_col]
         
     | 
| 
      
 267 
     | 
    
         
            +
             
     | 
| 
      
 268 
     | 
    
         
            +
                    # Label everything with assigned color / shape
         
     | 
| 
      
 269 
     | 
    
         
            +
                    if shape_dict or color_dict:
         
     | 
| 
      
 270 
     | 
    
         
            +
                        combined_labels = []
         
     | 
| 
      
 271 
     | 
    
         
            +
                        if isinstance(shape_dict, dict):
         
     | 
| 
      
 272 
     | 
    
         
            +
                            combined_labels.extend([item for sublist in shape_dict.values() for item in sublist])
         
     | 
| 
      
 273 
     | 
    
         
            +
                        if isinstance(color_dict, dict):
         
     | 
| 
      
 274 
     | 
    
         
            +
                            combined_labels.extend([item for sublist in color_dict.values() for item in sublist])
         
     | 
| 
      
 275 
     | 
    
         
            +
                        label_df = df[df[symbol_col].isin(combined_labels)]
         
     | 
| 
      
 276 
     | 
    
         
            +
             
     | 
| 
      
 277 
     | 
    
         
            +
                    # Label top n_gens
         
     | 
| 
      
 278 
     | 
    
         
            +
                    elif isinstance(to_label, int):
         
     | 
| 
      
 279 
     | 
    
         
            +
                        label_df = pd.concat(
         
     | 
| 
      
 280 
     | 
    
         
            +
                            (
         
     | 
| 
      
 281 
     | 
    
         
            +
                                df.sort_values("top_genes")[-to_label:],
         
     | 
| 
      
 282 
     | 
    
         
            +
                                df.sort_values("top_genes")[0:to_label],
         
     | 
| 
      
 283 
     | 
    
         
            +
                            )
         
     | 
| 
      
 284 
     | 
    
         
            +
                        )
         
     | 
| 
      
 285 
     | 
    
         
            +
             
     | 
| 
      
 286 
     | 
    
         
            +
                    # assume that a list of genes was passed to label
         
     | 
| 
      
 287 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 288 
     | 
    
         
            +
                        label_df = df[df[symbol_col].isin(to_label)]
         
     | 
| 
      
 289 
     | 
    
         
            +
             
     | 
| 
      
 290 
     | 
    
         
            +
                    # By default mode colors by up/down if no dict is passed
         
     | 
| 
      
 291 
     | 
    
         
            +
             
     | 
| 
      
 292 
     | 
    
         
            +
                    if color_dict is None:
         
     | 
| 
      
 293 
     | 
    
         
            +
                        df["color"] = df.apply(
         
     | 
| 
      
 294 
     | 
    
         
            +
                            lambda row: _map_genes_categories(
         
     | 
| 
      
 295 
     | 
    
         
            +
                                row,
         
     | 
| 
      
 296 
     | 
    
         
            +
                                log2fc_col=log2fc_col,
         
     | 
| 
      
 297 
     | 
    
         
            +
                                nlog10_col="nlog10",
         
     | 
| 
      
 298 
     | 
    
         
            +
                                log2fc_thresh=log2fc_thresh,
         
     | 
| 
      
 299 
     | 
    
         
            +
                                pval_thresh=pval_thresh,
         
     | 
| 
      
 300 
     | 
    
         
            +
                                s_curve=s_curve,
         
     | 
| 
      
 301 
     | 
    
         
            +
                            ),
         
     | 
| 
      
 302 
     | 
    
         
            +
                            axis=1,
         
     | 
| 
      
 303 
     | 
    
         
            +
                        )
         
     | 
| 
      
 304 
     | 
    
         
            +
             
     | 
| 
      
 305 
     | 
    
         
            +
                        # order of colors
         
     | 
| 
      
 306 
     | 
    
         
            +
                        hues = ["not DE", "Up", "Down"][: len(df.color.unique())]
         
     | 
| 
      
 307 
     | 
    
         
            +
             
     | 
| 
      
 308 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 309 
     | 
    
         
            +
                        df["color"] = df.apply(
         
     | 
| 
      
 310 
     | 
    
         
            +
                            lambda row: _map_genes_categories_highlight(
         
     | 
| 
      
 311 
     | 
    
         
            +
                                row,
         
     | 
| 
      
 312 
     | 
    
         
            +
                                log2fc_col=log2fc_col,
         
     | 
| 
      
 313 
     | 
    
         
            +
                                nlog10_col="nlog10",
         
     | 
| 
      
 314 
     | 
    
         
            +
                                log2fc_thresh=log2fc_thresh,
         
     | 
| 
      
 315 
     | 
    
         
            +
                                pval_thresh=pval_thresh,
         
     | 
| 
      
 316 
     | 
    
         
            +
                                symbol_col=symbol_col,
         
     | 
| 
      
 317 
     | 
    
         
            +
                                s_curve=s_curve,
         
     | 
| 
      
 318 
     | 
    
         
            +
                            ),
         
     | 
| 
      
 319 
     | 
    
         
            +
                            axis=1,
         
     | 
| 
      
 320 
     | 
    
         
            +
                        )
         
     | 
| 
      
 321 
     | 
    
         
            +
             
     | 
| 
      
 322 
     | 
    
         
            +
                        user_added_cats = [x for x in df.color.unique() if x not in ["DE", "not DE"]]
         
     | 
| 
      
 323 
     | 
    
         
            +
                        hues = ["DE", "not DE"] + user_added_cats
         
     | 
| 
      
 324 
     | 
    
         
            +
             
     | 
| 
      
 325 
     | 
    
         
            +
                        # order of colors
         
     | 
| 
      
 326 
     | 
    
         
            +
                        hues = hues[: len(df.color.unique())]
         
     | 
| 
      
 327 
     | 
    
         
            +
                        colors = [
         
     | 
| 
      
 328 
     | 
    
         
            +
                            "dimgrey",
         
     | 
| 
      
 329 
     | 
    
         
            +
                            "lightgrey",
         
     | 
| 
      
 330 
     | 
    
         
            +
                            "tab:blue",
         
     | 
| 
      
 331 
     | 
    
         
            +
                            "tab:orange",
         
     | 
| 
      
 332 
     | 
    
         
            +
                            "tab:green",
         
     | 
| 
      
 333 
     | 
    
         
            +
                            "tab:red",
         
     | 
| 
      
 334 
     | 
    
         
            +
                            "tab:purple",
         
     | 
| 
      
 335 
     | 
    
         
            +
                            "tab:brown",
         
     | 
| 
      
 336 
     | 
    
         
            +
                            "tab:pink",
         
     | 
| 
      
 337 
     | 
    
         
            +
                            "tab:olive",
         
     | 
| 
      
 338 
     | 
    
         
            +
                            "tab:cyan",
         
     | 
| 
      
 339 
     | 
    
         
            +
                        ]
         
     | 
| 
      
 340 
     | 
    
         
            +
             
     | 
| 
      
 341 
     | 
    
         
            +
                    # coloring if dictionary passed, subtle background + highlight
         
     | 
| 
      
 342 
     | 
    
         
            +
                    # map shapes if dictionary exists
         
     | 
| 
      
 343 
     | 
    
         
            +
                    if shape_dict is not None:
         
     | 
| 
      
 344 
     | 
    
         
            +
                        df["shape"] = df[symbol_col].map(_map_shape)
         
     | 
| 
      
 345 
     | 
    
         
            +
                        user_added_cats = [x for x in df["shape"].unique() if x != "other"]
         
     | 
| 
      
 346 
     | 
    
         
            +
                        shape_order = ["other"] + user_added_cats
         
     | 
| 
      
 347 
     | 
    
         
            +
                        if shapes is None:
         
     | 
| 
      
 348 
     | 
    
         
            +
                            shapes = ["o", "^", "s", "X", "*", "d"]
         
     | 
| 
      
 349 
     | 
    
         
            +
                        shapes = shapes[: len(df["shape"].unique())]
         
     | 
| 
      
 350 
     | 
    
         
            +
                        shape_col = "shape"
         
     | 
| 
      
 351 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 352 
     | 
    
         
            +
                        shape_col = None
         
     | 
| 
      
 353 
     | 
    
         
            +
             
     | 
| 
      
 354 
     | 
    
         
            +
                    # build palette
         
     | 
| 
      
 355 
     | 
    
         
            +
                    colors = colors[: len(df.color.unique())]
         
     | 
| 
      
 356 
     | 
    
         
            +
             
     | 
| 
      
 357 
     | 
    
         
            +
                    # We want plot highlighted genes on top + at bigger size, split dataframe
         
     | 
| 
      
 358 
     | 
    
         
            +
                    df_highlight = None
         
     | 
| 
      
 359 
     | 
    
         
            +
                    if shape_dict or color_dict:
         
     | 
| 
      
 360 
     | 
    
         
            +
                        label_genes = label_df[symbol_col].unique()
         
     | 
| 
      
 361 
     | 
    
         
            +
                        df_highlight = df[df[symbol_col].isin(label_genes)]
         
     | 
| 
      
 362 
     | 
    
         
            +
                        df = df[~df[symbol_col].isin(label_genes)]
         
     | 
| 
      
 363 
     | 
    
         
            +
             
     | 
| 
      
 364 
     | 
    
         
            +
                    plt.figure(figsize=figsize)
         
     | 
| 
      
 365 
     | 
    
         
            +
                    # Plot non-highlighted genes
         
     | 
| 
      
 366 
     | 
    
         
            +
                    ax = sns.scatterplot(
         
     | 
| 
      
 367 
     | 
    
         
            +
                        data=df,
         
     | 
| 
      
 368 
     | 
    
         
            +
                        x=log2fc_col,
         
     | 
| 
      
 369 
     | 
    
         
            +
                        y="nlog10",
         
     | 
| 
      
 370 
     | 
    
         
            +
                        hue="color",
         
     | 
| 
      
 371 
     | 
    
         
            +
                        hue_order=hues,
         
     | 
| 
      
 372 
     | 
    
         
            +
                        palette=colors,
         
     | 
| 
      
 373 
     | 
    
         
            +
                        size=size_col,
         
     | 
| 
      
 374 
     | 
    
         
            +
                        sizes=point_sizes,
         
     | 
| 
      
 375 
     | 
    
         
            +
                        style=shape_col,
         
     | 
| 
      
 376 
     | 
    
         
            +
                        style_order=shape_order,
         
     | 
| 
      
 377 
     | 
    
         
            +
                        markers=shapes,
         
     | 
| 
      
 378 
     | 
    
         
            +
                        **kwargs,
         
     | 
| 
      
 379 
     | 
    
         
            +
                    )
         
     | 
| 
      
 380 
     | 
    
         
            +
                    # Plot highlighted genes
         
     | 
| 
      
 381 
     | 
    
         
            +
                    if df_highlight is not None:
         
     | 
| 
      
 382 
     | 
    
         
            +
                        ax = sns.scatterplot(
         
     | 
| 
      
 383 
     | 
    
         
            +
                            data=df_highlight,
         
     | 
| 
      
 384 
     | 
    
         
            +
                            x=log2fc_col,
         
     | 
| 
      
 385 
     | 
    
         
            +
                            y="nlog10",
         
     | 
| 
      
 386 
     | 
    
         
            +
                            hue="color",
         
     | 
| 
      
 387 
     | 
    
         
            +
                            hue_order=hues,
         
     | 
| 
      
 388 
     | 
    
         
            +
                            palette=colors,
         
     | 
| 
      
 389 
     | 
    
         
            +
                            size=size_col,
         
     | 
| 
      
 390 
     | 
    
         
            +
                            sizes=point_sizes,
         
     | 
| 
      
 391 
     | 
    
         
            +
                            style=shape_col,
         
     | 
| 
      
 392 
     | 
    
         
            +
                            style_order=shape_order,
         
     | 
| 
      
 393 
     | 
    
         
            +
                            markers=shapes,
         
     | 
| 
      
 394 
     | 
    
         
            +
                            legend=False,
         
     | 
| 
      
 395 
     | 
    
         
            +
                            edgecolor="black",
         
     | 
| 
      
 396 
     | 
    
         
            +
                            linewidth=1,
         
     | 
| 
      
 397 
     | 
    
         
            +
                            **kwargs,
         
     | 
| 
      
 398 
     | 
    
         
            +
                        )
         
     | 
| 
      
 399 
     | 
    
         
            +
             
     | 
| 
      
 400 
     | 
    
         
            +
                    # plot vertical and horizontal lines
         
     | 
| 
      
 401 
     | 
    
         
            +
                    if s_curve:
         
     | 
| 
      
 402 
     | 
    
         
            +
                        x = np.arange((log2fc_thresh + 0.000001), y_max, 0.01)
         
     | 
| 
      
 403 
     | 
    
         
            +
                        y = _pval_reciprocal(x)
         
     | 
| 
      
 404 
     | 
    
         
            +
                        ax.plot(x, y, zorder=1, c="k", lw=2, ls="--")
         
     | 
| 
      
 405 
     | 
    
         
            +
                        ax.plot(-x, y, zorder=1, c="k", lw=2, ls="--")
         
     | 
| 
      
 406 
     | 
    
         
            +
             
     | 
| 
      
 407 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 408 
     | 
    
         
            +
                        ax.axhline(pval_thresh, zorder=1, c="k", lw=2, ls="--")
         
     | 
| 
      
 409 
     | 
    
         
            +
                        ax.axvline(log2fc_thresh, zorder=1, c="k", lw=2, ls="--")
         
     | 
| 
      
 410 
     | 
    
         
            +
                        ax.axvline(log2fc_thresh * -1, zorder=1, c="k", lw=2, ls="--")
         
     | 
| 
      
 411 
     | 
    
         
            +
                    plt.ylim(0, y_max)
         
     | 
| 
      
 412 
     | 
    
         
            +
                    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
         
     | 
| 
      
 413 
     | 
    
         
            +
             
     | 
| 
      
 414 
     | 
    
         
            +
                    # make labels
         
     | 
| 
      
 415 
     | 
    
         
            +
                    texts = []
         
     | 
| 
      
 416 
     | 
    
         
            +
                    for i in range(len(label_df)):
         
     | 
| 
      
 417 
     | 
    
         
            +
                        txt = plt.text(
         
     | 
| 
      
 418 
     | 
    
         
            +
                            x=label_df.iloc[i][log2fc_col],
         
     | 
| 
      
 419 
     | 
    
         
            +
                            y=label_df.iloc[i].nlog10,
         
     | 
| 
      
 420 
     | 
    
         
            +
                            s=label_df.iloc[i][symbol_col],
         
     | 
| 
      
 421 
     | 
    
         
            +
                            fontsize=fontsize,
         
     | 
| 
      
 422 
     | 
    
         
            +
                        )
         
     | 
| 
      
 423 
     | 
    
         
            +
             
     | 
| 
      
 424 
     | 
    
         
            +
                        txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground="w")])
         
     | 
| 
      
 425 
     | 
    
         
            +
                        texts.append(txt)
         
     | 
| 
      
 426 
     | 
    
         
            +
             
     | 
| 
      
 427 
     | 
    
         
            +
                    adjustText.adjust_text(texts, arrowprops={"arrowstyle": "-", "color": "k", "zorder": 5})
         
     | 
| 
      
 428 
     | 
    
         
            +
             
     | 
| 
      
 429 
     | 
    
         
            +
                    # make things pretty
         
     | 
| 
      
 430 
     | 
    
         
            +
                    for axis in ["bottom", "left", "top", "right"]:
         
     | 
| 
      
 431 
     | 
    
         
            +
                        ax.spines[axis].set_linewidth(2)
         
     | 
| 
      
 432 
     | 
    
         
            +
             
     | 
| 
      
 433 
     | 
    
         
            +
                    if not top_right_frame:
         
     | 
| 
      
 434 
     | 
    
         
            +
                        ax.spines["top"].set_visible(False)
         
     | 
| 
      
 435 
     | 
    
         
            +
                        ax.spines["right"].set_visible(False)
         
     | 
| 
      
 436 
     | 
    
         
            +
             
     | 
| 
      
 437 
     | 
    
         
            +
                    ax.tick_params(width=2)
         
     | 
| 
      
 438 
     | 
    
         
            +
                    plt.xticks(size=11, fontsize=10)
         
     | 
| 
      
 439 
     | 
    
         
            +
                    plt.yticks(size=11)
         
     | 
| 
      
 440 
     | 
    
         
            +
             
     | 
| 
      
 441 
     | 
    
         
            +
                    # Set default axis titles
         
     | 
| 
      
 442 
     | 
    
         
            +
                    if x_label is None:
         
     | 
| 
      
 443 
     | 
    
         
            +
                        x_label = log2fc_col
         
     | 
| 
      
 444 
     | 
    
         
            +
                    if y_label is None:
         
     | 
| 
      
 445 
     | 
    
         
            +
                        y_label = f"-$log_{{10}}$ {pvalue_col}"
         
     | 
| 
      
 446 
     | 
    
         
            +
             
     | 
| 
      
 447 
     | 
    
         
            +
                    plt.xlabel(x_label, size=15)
         
     | 
| 
      
 448 
     | 
    
         
            +
                    plt.ylabel(y_label, size=15)
         
     | 
| 
      
 449 
     | 
    
         
            +
             
     | 
| 
      
 450 
     | 
    
         
            +
                    plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
         
     | 
| 
      
 451 
     | 
    
         
            +
             
     | 
| 
      
 452 
     | 
    
         
            +
                    # TODO replace with scanpy save style
         
     | 
| 
      
 453 
     | 
    
         
            +
                    if save:
         
     | 
| 
      
 454 
     | 
    
         
            +
                        files = os.listdir()
         
     | 
| 
      
 455 
     | 
    
         
            +
                        for x in range(100):
         
     | 
| 
      
 456 
     | 
    
         
            +
                            file_pref = "volcano_" + "%02d" % (x,)
         
     | 
| 
      
 457 
     | 
    
         
            +
                            if len([x for x in files if x.startswith(file_pref)]) == 0:
         
     | 
| 
      
 458 
     | 
    
         
            +
                                plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
         
     | 
| 
      
 459 
     | 
    
         
            +
                                plt.savefig(file_pref + ".svg", bbox_inches="tight")
         
     | 
| 
      
 460 
     | 
    
         
            +
                                break
         
     | 
| 
      
 461 
     | 
    
         
            +
                    elif isinstance(save, str):
         
     | 
| 
      
 462 
     | 
    
         
            +
                        plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
         
     | 
| 
      
 463 
     | 
    
         
            +
                        plt.savefig(save + ".svg", bbox_inches="tight")
         
     | 
| 
      
 464 
     | 
    
         
            +
             
     | 
| 
      
 465 
     | 
    
         
            +
                    plt.show()
         
     | 
| 
      
 466 
     | 
    
         
            +
             
     | 
| 
      
 467 
     | 
    
         
            +
             
     | 
| 
      
 468 
     | 
    
         
            +
            class LinearModelBase(MethodBase):
         
     | 
| 
      
 469 
     | 
    
         
            +
                def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
                    """Set up a linear-model-based differential expression method.

                    Args:
                        adata: AnnData object, usually pseudobulked.
                        design: Model design. Either a ready-made design matrix (stored unchanged), or a
                            formulaic formula string such as ``'x + z'`` or ``'~x+z'``, which is
                            materialized into a model matrix from ``adata.obs``.
                        mask: A column in adata.var that contains a boolean mask with selected features.
                        layer: Layer to use in fit(). If None, use the X array.
                        **kwargs: Not used by this initializer.
                    """
                    super().__init__(adata, mask=mask, layer=layer)
                    self._check_counts()

                    # Factor bookkeeping is only filled in when the design arrives as a formula string.
                    self.factor_storage = None
                    self.variable_to_factors = None

                    if not isinstance(design, str):
                        # Pre-built design matrix: keep it exactly as given.
                        self.design = design
                        return

                    storage, var_to_factors, materializer_cls = get_factor_storage_and_materializer()
                    self.factor_storage = storage
                    self.variable_to_factors = var_to_factors
                    # record_factor_metadata=True keeps per-factor metadata on the resulting matrix.
                    materializer = materializer_cls(adata.obs, record_factor_metadata=True)
                    self.design = materializer.get_model_matrix(design)
         
     | 
| 
      
 491 
     | 
    
         
            +
             
     | 
| 
      
 492 
     | 
    
         
            +
                @classmethod
                def compare_groups(
                    cls,
                    adata,
                    column,
                    baseline,
                    groups_to_compare,
                    *,
                    paired_by=None,
                    mask=None,
                    layer=None,
                    fit_kwargs=MappingProxyType({}),
                    test_kwargs=MappingProxyType({}),
                ):
                    """Fit a model on ``column`` and test each requested group against ``baseline``.

                    Builds the formula ``~<column>``, with ``+<paired_by>`` appended when a pairing
                    variable is given. It constructs the model with that formula and fits it, then tests
                    one contrast per requested group.

                    Args:
                        adata: AnnData object handed to the model constructor.
                        column: Name of the variable whose levels are compared.
                        baseline: Level of ``column`` used as the reference.
                        groups_to_compare: A single level (str) or an iterable of levels of ``column``
                            to compare against ``baseline``.
                        paired_by: Optional extra variable added as a term to the formula.
                        mask: Forwarded to the constructor.
                        layer: Forwarded to the constructor.
                        fit_kwargs: Keyword arguments passed to ``fit()``.
                        test_kwargs: Keyword arguments passed to ``test_contrasts()``.

                    Returns:
                        The result of ``test_contrasts`` for the mapping ``{group: contrast}``.
                    """
                    formula = f"~{column}" if paired_by is None else f"~{column}+{paired_by}"
                    # A bare string is a single group, not an iterable of characters.
                    groups = [groups_to_compare] if isinstance(groups_to_compare, str) else groups_to_compare

                    model = cls(adata, design=formula, mask=mask, layer=layer)
                    model.fit(**fit_kwargs)

                    contrasts = {}
                    for group in groups:
                        contrasts[group] = model.contrast(column=column, baseline=baseline, group_to_compare=group)

                    return model.test_contrasts(contrasts, **test_kwargs)
         
     | 
| 
      
 524 
     | 
    
         
            +
             
     | 
| 
      
 525 
     | 
    
         
            +
                @property
                def variables(self):
                    """Names of the data variables referenced by the model's formula.

                    Raises:
                        ValueError: if looking up ``self.design.model_spec.variables_by_source`` fails with an
                            AttributeError, which is reported as the model not having been built from a formula.
                    """
                    try:
                        model_spec = self.design.model_spec
                        data_variables = model_spec.variables_by_source["data"]
                    except AttributeError:
                        # Suppress the AttributeError context: the user-facing cause is the missing formula.
                        raise ValueError(
                            "Retrieving variables is only possible if the model was initialized using a formula."
                        ) from None
                    return data_variables
         
     | 
| 
      
 534 
     | 
    
         
            +
             
     | 
| 
      
 535 
     | 
    
         
            +
                @abstractmethod
         
     | 
| 
      
 536 
     | 
    
         
            +
                def _check_counts(self):
         
     | 
| 
      
 537 
     | 
    
         
            +
                    """
         
     | 
| 
      
 538 
     | 
    
         
            +
                    Check that counts are valid for the specific method.
         
     | 
| 
      
 539 
     | 
    
         
            +
             
     | 
| 
      
 540 
     | 
    
         
            +
                    Raises:
         
     | 
| 
      
 541 
     | 
    
         
            +
                        ValueError: if the data matrix does not comply with the expectations.
         
     | 
| 
      
 542 
     | 
    
         
            +
                    """
         
     | 
| 
      
 543 
     | 
    
         
            +
                    ...
         
     | 
| 
      
 544 
     | 
    
         
            +
             
     | 
| 
      
 545 
     | 
    
         
            +
                @abstractmethod
                def fit(self, **kwargs):
                    """Fit the model; each concrete subclass supplies the actual fitting routine.

                    Args:
                        **kwargs: Method-specific fitting options.
                    """
         
     | 
| 
      
 554 
     | 
    
         
            +
             
     | 
| 
      
 555 
     | 
    
         
            +
                @abstractmethod
         
     | 
| 
      
 556 
     | 
    
         
            +
                def _test_single_contrast(self, contrast, **kwargs): ...
         
     | 
| 
      
 557 
     | 
    
         
            +
             
     | 
| 
      
 558 
     | 
    
         
            +
                def test_contrasts(self, contrasts, **kwargs):
         
     | 
| 
      
 559 
     | 
    
         
            +
                    """
         
     | 
| 
      
 560 
     | 
    
         
            +
                    Perform a comparison as specified in a contrast vector.
         
     | 
| 
      
 561 
     | 
    
         
            +
             
     | 
| 
      
 562 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 563 
     | 
    
         
            +
                        contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
         
     | 
| 
      
 564 
     | 
    
         
            +
                        **kwargs: passed to the respective implementation.
         
     | 
| 
      
 565 
     | 
    
         
            +
             
     | 
| 
      
 566 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 567 
     | 
    
         
            +
                        A dataframe with the results.
         
     | 
| 
      
 568 
     | 
    
         
            +
                    """
         
     | 
| 
      
 569 
     | 
    
         
            +
                    if not isinstance(contrasts, dict):
         
     | 
| 
      
 570 
     | 
    
         
            +
                        contrasts = {None: contrasts}
         
     | 
| 
      
 571 
     | 
    
         
            +
                    results = []
         
     | 
| 
      
 572 
     | 
    
         
            +
                    for name, contrast in contrasts.items():
         
     | 
| 
      
 573 
     | 
    
         
            +
                        results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
         
     | 
| 
      
 574 
     | 
    
         
            +
             
     | 
| 
      
 575 
     | 
    
         
            +
                    results_df = pd.concat(results)
         
     | 
| 
      
 576 
     | 
    
         
            +
                    return results_df
         
     | 
| 
      
 577 
     | 
    
         
            +
             
     | 
| 
      
 578 
     | 
    
         
            +
                def test_reduced(self, modelB):
         
     | 
| 
      
 579 
     | 
    
         
            +
                    """
         
     | 
| 
      
 580 
     | 
    
         
            +
                    Test against a reduced model.
         
     | 
| 
      
 581 
     | 
    
         
            +
             
     | 
| 
      
 582 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 583 
     | 
    
         
            +
                        modelB: the reduced model against which to test.
         
     | 
| 
      
 584 
     | 
    
         
            +
             
     | 
| 
      
 585 
     | 
    
         
            +
                    Example:
         
     | 
| 
      
 586 
     | 
    
         
            +
                        modelA = Model().fit()
         
     | 
| 
      
 587 
     | 
    
         
            +
                        modelB = Model().fit()
         
     | 
| 
      
 588 
     | 
    
         
            +
                        modelA.test_reduced(modelB)
         
     | 
| 
      
 589 
     | 
    
         
            +
                    """
         
     | 
| 
      
 590 
     | 
    
         
            +
                    raise NotImplementedError
         
     | 
| 
      
 591 
     | 
    
         
            +
             
     | 
| 
      
 592 
     | 
    
         
            +
                def cond(self, **kwargs):
                    """
                    Get a contrast vector representing a specific condition.

                    Every model variable not given in ``kwargs`` is filled with ``self._get_default_value(var)``.
                    Every given value is validated with ``self._check_category``.

                    Args:
                        **kwargs: column/value pairs.

                    Returns:
                        A contrast vector that aligns to the columns of the design matrix.

                    Raises:
                        RuntimeError: if ``self.factor_storage`` is None, i.e. the model was not specified with a
                            formulaic formula.
                        ValueError: if a key is not one of the model variables (``_check_category`` may also raise
                            ValueError for an unknown category).
                    """
                    if self.factor_storage is None:
                        raise RuntimeError(
                            "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
                        )
                    # `variables` is a property that walks the design's model spec; look it up once.
                    model_variables = self.variables
                    # Copy explicitly so the defaults added below are visibly part of the frame built from
                    # `cond_dict`, instead of only reaching it through aliasing of `kwargs`.
                    cond_dict = dict(kwargs)
                    if not set(cond_dict).issubset(model_variables):
                        raise ValueError(
                            "You specified a variable that is not part of the model. Available variables: "
                            + ",".join(model_variables)
                        )
                    for var in model_variables:
                        if var in cond_dict:
                            self._check_category(var, cond_dict[var])
                        else:
                            cond_dict[var] = self._get_default_value(var)
                    df = pd.DataFrame([cond_dict])
                    return self.design.model_spec.get_model_matrix(df).iloc[0]
         
     | 
| 
      
 619 
     | 
    
         
            +
             
     | 
| 
      
 620 
     | 
    
         
            +
                def _get_factor_metadata_for_variable(self, var):
         
     | 
| 
      
 621 
     | 
    
         
            +
                    factors = self.variable_to_factors[var]
         
     | 
| 
      
 622 
     | 
    
         
            +
                    return list(chain.from_iterable(self.factor_storage[f] for f in factors))
         
     | 
| 
      
 623 
     | 
    
         
            +
             
     | 
| 
      
 624 
     | 
    
         
            +
                def _get_default_value(self, var):
                    """Return the value ``cond`` uses for a model variable the caller did not specify.

                    Non-categorical variables default to ``0``. Categorical variables default to their resolved
                    base category, or ``"\\0"`` when the resolved base is None.

                    Raises:
                        ValueError: if the base category of a categorical variable cannot be resolved
                            unambiguously.
                    """
                    factor_metadata = self._get_factor_metadata_for_variable(var)
                    if resolve_ambiguous(factor_metadata, "kind") != Factor.Kind.CATEGORICAL:
                        return 0
                    try:
                        tmp_base = resolve_ambiguous(factor_metadata, "base")
                    except AmbiguousAttributeError as e:
                        raise ValueError(
                            f"Could not automatically resolve base category for variable {var}. Please specify it explicitly in `model.cond`."
                        ) from e
                    # NOTE(review): "\0" is presumably a placeholder level for "no explicit base" that the model
                    # spec handles when building the design row; confirm against formulaic's encoding.
                    return tmp_base if tmp_base is not None else "\0"
         
     | 
| 
      
 636 
     | 
    
         
            +
             
     | 
| 
      
 637 
     | 
    
         
            +
                def _check_category(self, var, value):
                    """Validate that ``value`` is a known category of ``var``; non-categorical variables always pass.

                    Raises:
                        ValueError: if ``var`` is categorical and ``value`` is not one of its categories.
                    """
                    factor_metadata = self._get_factor_metadata_for_variable(var)
                    if resolve_ambiguous(factor_metadata, "kind") != Factor.Kind.CATEGORICAL:
                        # Categories are only meaningful (and only resolved) for categorical variables.
                        return
                    tmp_categories = resolve_ambiguous(factor_metadata, "categories")
                    if value not in tmp_categories:
                        # map(str, ...) so non-string categories (e.g. integers) do not break the message.
                        raise ValueError(
                            f"You specified a non-existent category for {var}. Possible categories: {', '.join(map(str, tmp_categories))}"
                        )
         
     | 
| 
      
 644 
     | 
    
         
            +
             
     | 
| 
      
 645 
     | 
    
         
            +
                def contrast(self, column, baseline, group_to_compare):
                    """
                    Create the contrast vector for one pairwise comparison between two categories.

                    Args:
                        column: column in adata.obs to test on.
                        baseline: reference category (denominator).
                        group_to_compare: category compared against ``baseline`` (numerator).

                    Returns:
                        ``self.cond(column=group_to_compare) - self.cond(column=baseline)``, a numeric contrast vector.
                    """
                    compared = self.cond(**{column: group_to_compare})
                    reference = self.cond(**{column: baseline})
                    return compared - reference
         
     | 
| 
         @@ -0,0 +1,41 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 2 
     | 
    
         
            +
            from scipy.sparse import issparse, spmatrix
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
         
     | 
| 
      
 6 
     | 
    
         
            +
                """Check if a matrix is numeric and only contains finite/non-NA values.
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 9 
     | 
    
         
            +
                    array: Dense or sparse matrix to check.
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
                Raises:
         
     | 
| 
      
 12 
     | 
    
         
            +
                    ValueError: If the matrix is not numeric or contains NaNs or infinite values.
         
     | 
| 
      
 13 
     | 
    
         
            +
                """
         
     | 
| 
      
 14 
     | 
    
         
            +
                if not np.issubdtype(array.dtype, np.number):
         
     | 
| 
      
 15 
     | 
    
         
            +
                    raise ValueError("Counts must be numeric.")
         
     | 
| 
      
 16 
     | 
    
         
            +
                if issparse(array):
         
     | 
| 
      
 17 
     | 
    
         
            +
                    if np.any(~np.isfinite(array.data)):
         
     | 
| 
      
 18 
     | 
    
         
            +
                        raise ValueError("Counts cannot contain negative, NaN or Inf values.")
         
     | 
| 
      
 19 
     | 
    
         
            +
                else:
         
     | 
| 
      
 20 
     | 
    
         
            +
                    if np.any(~np.isfinite(array)):
         
     | 
| 
      
 21 
     | 
    
         
            +
                        raise ValueError("Counts cannot contain negative, NaN or Inf values.")
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
            def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
         
     | 
| 
      
 25 
     | 
    
         
            +
                """Check if a matrix container integers, or floats that are close to integers.
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 28 
     | 
    
         
            +
                    array: Dense or sparse matrix to check.
         
     | 
| 
      
 29 
     | 
    
         
            +
                    tolerance: Values must be this close to integers.
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
                Raises:
         
     | 
| 
      
 32 
     | 
    
         
            +
                    ValueError: If the matrix contains values that are not close to integers.
         
     | 
| 
      
 33 
     | 
    
         
            +
                """
         
     | 
| 
      
 34 
     | 
    
         
            +
                if issparse(array):
         
     | 
| 
      
 35 
     | 
    
         
            +
                    if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
         
     | 
| 
      
 36 
     | 
    
         
            +
                        raise ValueError("Non-zero elements of the matrix must be close to integer values.")
         
     | 
| 
      
 37 
     | 
    
         
            +
                else:
         
     | 
| 
      
 38 
     | 
    
         
            +
                    if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
         
     | 
| 
      
 39 
     | 
    
         
            +
                        raise ValueError("Matrix must be a count matrix.")
         
     | 
| 
      
 40 
     | 
    
         
            +
                if (array < 0).sum() > 0:
         
     | 
| 
      
 41 
     | 
    
         
            +
                    raise ValueError("Non-zero elements of the matrix must be positive.")
         
     |