pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
 - pertpy/data/__init__.py +5 -1
 - pertpy/data/_dataloader.py +2 -4
 - pertpy/data/_datasets.py +203 -92
 - pertpy/metadata/__init__.py +4 -0
 - pertpy/metadata/_cell_line.py +826 -0
 - pertpy/metadata/_compound.py +129 -0
 - pertpy/metadata/_drug.py +242 -0
 - pertpy/metadata/_look_up.py +582 -0
 - pertpy/metadata/_metadata.py +73 -0
 - pertpy/metadata/_moa.py +129 -0
 - pertpy/plot/__init__.py +1 -9
 - pertpy/plot/_augur.py +53 -116
 - pertpy/plot/_coda.py +277 -677
 - pertpy/plot/_guide_rna.py +17 -35
 - pertpy/plot/_milopy.py +59 -134
 - pertpy/plot/_mixscape.py +152 -391
 - pertpy/preprocessing/_guide_rna.py +88 -4
 - pertpy/tools/__init__.py +8 -13
 - pertpy/tools/_augur.py +315 -17
 - pertpy/tools/_cinemaot.py +143 -4
 - pertpy/tools/_coda/_base_coda.py +1210 -65
 - pertpy/tools/_coda/_sccoda.py +50 -21
 - pertpy/tools/_coda/_tasccoda.py +27 -19
 - pertpy/tools/_dialogue.py +164 -56
 - pertpy/tools/_differential_gene_expression.py +240 -14
 - pertpy/tools/_distances/_distance_tests.py +8 -8
 - pertpy/tools/_distances/_distances.py +184 -34
 - pertpy/tools/_enrichment.py +465 -0
 - pertpy/tools/_milo.py +345 -11
 - pertpy/tools/_mixscape.py +668 -50
 - pertpy/tools/_perturbation_space/_clustering.py +5 -1
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
 - pertpy/tools/_perturbation_space/_simple.py +51 -10
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_scgen.py +701 -0
 - pertpy/tools/_scgen/_utils.py +1 -3
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
 - pertpy-0.7.0.dist-info/RECORD +53 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
 
| 
         @@ -1,12 +1,16 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from __future__ import annotations
         
     | 
| 
       2 
2 
     | 
    
         | 
| 
      
 3 
     | 
    
         
            +
            import uuid
         
     | 
| 
       3 
4 
     | 
    
         
             
            from typing import TYPE_CHECKING
         
     | 
| 
       4 
5 
     | 
    
         | 
| 
       5 
6 
     | 
    
         
             
            import numpy as np
         
     | 
| 
      
 7 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 8 
     | 
    
         
            +
            import scanpy as sc
         
     | 
| 
       6 
9 
     | 
    
         
             
            import scipy
         
     | 
| 
       7 
10 
     | 
    
         | 
| 
       8 
11 
     | 
    
         
             
            if TYPE_CHECKING:
         
     | 
| 
       9 
12 
     | 
    
         
             
                from anndata import AnnData
         
     | 
| 
      
 13 
     | 
    
         
            +
                from matplotlib.axes import Axes
         
     | 
| 
       10 
14 
     | 
    
         | 
| 
       11 
15 
     | 
    
         | 
| 
       12 
16 
     | 
    
         
             
            class GuideAssignment:
         
     | 
| 
         @@ -39,7 +43,7 @@ class GuideAssignment: 
     | 
|
| 
       39 
43 
     | 
    
         | 
| 
       40 
44 
     | 
    
         
             
                        >>> import pertpy as pt
         
     | 
| 
       41 
45 
     | 
    
         
             
                        >>> mdata = pt.data.papalexi_2021()
         
     | 
| 
       42 
     | 
    
         
            -
                        >>> gdo = mdata.mod[ 
     | 
| 
      
 46 
     | 
    
         
            +
                        >>> gdo = mdata.mod["gdo"]
         
     | 
| 
       43 
47 
     | 
    
         
             
                        >>> ga = pt.pp.GuideAssignment()
         
     | 
| 
       44 
48 
     | 
    
         
             
                        >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
         
     | 
| 
       45 
49 
     | 
    
         
             
                    """
         
     | 
| 
         @@ -71,7 +75,6 @@ class GuideAssignment: 
     | 
|
| 
       71 
75 
     | 
    
         | 
| 
       72 
76 
     | 
    
         
             
                    Args:
         
     | 
| 
       73 
77 
     | 
    
         
             
                        adata: Annotated data matrix containing gRNA values
         
     | 
| 
       74 
     | 
    
         
            -
                               assignment_threshold: If a gRNA is available for at least `assignment_threshold`, it will be recognized as assigned.
         
     | 
| 
       75 
78 
     | 
    
         
             
                        assignment_threshold: The count threshold that is required for an assignment to be viable.
         
     | 
| 
       76 
79 
     | 
    
         
             
                        layer: Key to the layer containing raw count values of the gRNAs.
         
     | 
| 
       77 
80 
     | 
    
         
             
                               adata.X is used if layer is None. Expects count data.
         
     | 
| 
         @@ -83,8 +86,8 @@ class GuideAssignment: 
     | 
|
| 
       83 
86 
     | 
    
         
             
                        Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
         
     | 
| 
       84 
87 
     | 
    
         | 
| 
       85 
88 
     | 
    
         
             
                        >>> import pertpy as pt
         
     | 
| 
       86 
     | 
    
         
            -
                        >>> mdata = pt. 
     | 
| 
       87 
     | 
    
         
            -
                        >>> gdo = mdata.mod[ 
     | 
| 
      
 89 
     | 
    
         
            +
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
      
 90 
     | 
    
         
            +
                        >>> gdo = mdata.mod["gdo"]
         
     | 
| 
       88 
91 
     | 
    
         
             
                        >>> ga = pt.pp.GuideAssignment()
         
     | 
| 
       89 
92 
     | 
    
         
             
                        >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
         
     | 
| 
       90 
93 
     | 
    
         
             
                    """
         
     | 
| 
         @@ -103,3 +106,84 @@ class GuideAssignment: 
     | 
|
| 
       103 
106 
     | 
    
         
             
                    adata.obs[output_key] = assigned_grna
         
     | 
| 
       104 
107 
     | 
    
         | 
| 
       105 
108 
     | 
    
         
             
                    return None
         
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
                def plot_heatmap(
         
     | 
| 
      
 111 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 112 
     | 
    
         
            +
                    adata: AnnData,
         
     | 
| 
      
 113 
     | 
    
         
            +
                    layer: str | None = None,
         
     | 
| 
      
 114 
     | 
    
         
            +
                    order_by: np.ndarray | str | None = None,
         
     | 
| 
      
 115 
     | 
    
         
            +
                    key_to_save_order: str = None,
         
     | 
| 
      
 116 
     | 
    
         
            +
                    **kwargs,
         
     | 
| 
      
 117 
     | 
    
         
            +
                ) -> list[Axes]:
         
     | 
| 
      
 118 
     | 
    
         
            +
                    """Heatmap plotting of guide RNA expression matrix.
         
     | 
| 
      
 119 
     | 
    
         
            +
             
     | 
| 
      
 120 
     | 
    
         
            +
                    Assuming guides have sparse expression, this function reorders cells
         
     | 
| 
      
 121 
     | 
    
         
            +
                    and plots guide RNA expression so that a nice sparse representation is achieved.
         
     | 
| 
      
 122 
     | 
    
         
            +
                    The cell ordering can be stored and reused in future plots to obtain consistent
         
     | 
| 
      
 123 
     | 
    
         
            +
                    plots before and after analysis of the guide RNA expression.
         
     | 
| 
      
 124 
     | 
    
         
            +
                    Note: This function expects a log-normalized or binary data.
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
      
 126 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 127 
     | 
    
         
            +
                        adata: Annotated data matrix containing gRNA values
         
     | 
| 
      
 128 
     | 
    
         
            +
                        layer: Key to the layer containing log normalized count values of the gRNAs.
         
     | 
| 
      
 129 
     | 
    
         
            +
                               adata.X is used if layer is None.
         
     | 
| 
      
 130 
     | 
    
         
            +
                        order_by: The order of cells in y axis. Defaults to None.
         
     | 
| 
      
 131 
     | 
    
         
            +
                                  If None, cells will be reordered to have a nice sparse representation.
         
     | 
| 
      
 132 
     | 
    
         
            +
                                  If a string is provided, adata.obs[order_by] will be used as the order.
         
     | 
| 
      
 133 
     | 
    
         
            +
                                  If a numpy array is provided, the array will be used for ordering.
         
     | 
| 
      
 134 
     | 
    
         
            +
                        key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
         
     | 
| 
      
 135 
     | 
    
         
            +
                        kwargs: Are passed to sc.pl.heatmap.
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
      
 137 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 138 
     | 
    
         
            +
                        List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
         
     | 
| 
      
 139 
     | 
    
         
            +
                        Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
         
     | 
| 
      
 140 
     | 
    
         
            +
             
     | 
| 
      
 141 
     | 
    
         
            +
                    Examples:
         
     | 
| 
      
 142 
     | 
    
         
            +
                        Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
         
     | 
| 
      
 143 
     | 
    
         
            +
                        visualized using a heatmap.
         
     | 
| 
      
 144 
     | 
    
         
            +
             
     | 
| 
      
 145 
     | 
    
         
            +
                        >>> import pertpy as pt
         
     | 
| 
      
 146 
     | 
    
         
            +
                        >>> mdata = pt.dt.papalexi_2021()
         
     | 
| 
      
 147 
     | 
    
         
            +
                        >>> gdo = mdata.mod["gdo"]
         
     | 
| 
      
 148 
     | 
    
         
            +
                        >>> ga = pt.pp.GuideAssignment()
         
     | 
| 
      
 149 
     | 
    
         
            +
                        >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
         
     | 
| 
      
 150 
     | 
    
         
            +
                        >>> ga.plot_heatmap(gdo)
         
     | 
| 
      
 151 
     | 
    
         
            +
                    """
         
     | 
| 
      
 152 
     | 
    
         
            +
                    data = adata.X if layer is None else adata.layers[layer]
         
     | 
| 
      
 153 
     | 
    
         
            +
             
     | 
| 
      
 154 
     | 
    
         
            +
                    if order_by is None:
         
     | 
| 
      
 155 
     | 
    
         
            +
                        if scipy.sparse.issparse(data):
         
     | 
| 
      
 156 
     | 
    
         
            +
                            max_values = data.max(axis=1).A.squeeze()
         
     | 
| 
      
 157 
     | 
    
         
            +
                            data_argmax = data.argmax(axis=1).A.squeeze()
         
     | 
| 
      
 158 
     | 
    
         
            +
                            max_guide_index = np.where(max_values != data.min(axis=1).A.squeeze(), data_argmax, -1)
         
     | 
| 
      
 159 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 160 
     | 
    
         
            +
                            max_guide_index = np.where(
         
     | 
| 
      
 161 
     | 
    
         
            +
                                data.max(axis=1).squeeze() != data.min(axis=1).squeeze(), data.argmax(axis=1).squeeze(), -1
         
     | 
| 
      
 162 
     | 
    
         
            +
                            )
         
     | 
| 
      
 163 
     | 
    
         
            +
                        order = np.argsort(max_guide_index)
         
     | 
| 
      
 164 
     | 
    
         
            +
                    elif isinstance(order_by, str):
         
     | 
| 
      
 165 
     | 
    
         
            +
                        order = np.argsort(adata.obs[order_by])
         
     | 
| 
      
 166 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 167 
     | 
    
         
            +
                        order = order_by
         
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
                    temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
         
     | 
| 
      
 170 
     | 
    
         
            +
                    adata.obs[temp_col_name] = pd.Categorical(["" for _ in range(adata.shape[0])])
         
     | 
| 
      
 171 
     | 
    
         
            +
             
     | 
| 
      
 172 
     | 
    
         
            +
                    if key_to_save_order is not None:
         
     | 
| 
      
 173 
     | 
    
         
            +
                        adata.obs[key_to_save_order] = pd.Categorical(order)
         
     | 
| 
      
 174 
     | 
    
         
            +
             
     | 
| 
      
 175 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 176 
     | 
    
         
            +
                        axis_group = sc.pl.heatmap(
         
     | 
| 
      
 177 
     | 
    
         
            +
                            adata[order, :],
         
     | 
| 
      
 178 
     | 
    
         
            +
                            var_names=adata.var.index.tolist(),
         
     | 
| 
      
 179 
     | 
    
         
            +
                            groupby=temp_col_name,
         
     | 
| 
      
 180 
     | 
    
         
            +
                            cmap="viridis",
         
     | 
| 
      
 181 
     | 
    
         
            +
                            use_raw=False,
         
     | 
| 
      
 182 
     | 
    
         
            +
                            dendrogram=False,
         
     | 
| 
      
 183 
     | 
    
         
            +
                            layer=layer,
         
     | 
| 
      
 184 
     | 
    
         
            +
                            **kwargs,
         
     | 
| 
      
 185 
     | 
    
         
            +
                        )
         
     | 
| 
      
 186 
     | 
    
         
            +
                    finally:
         
     | 
| 
      
 187 
     | 
    
         
            +
                        del adata.obs[temp_col_name]
         
     | 
| 
      
 188 
     | 
    
         
            +
             
     | 
| 
      
 189 
     | 
    
         
            +
                    return axis_group
         
     | 
    
        pertpy/tools/__init__.py
    CHANGED
    
    | 
         @@ -1,24 +1,19 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from rich import print
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
1 
     | 
    
         
             
            from pertpy.tools._augur import Augur
         
     | 
| 
       4 
2 
     | 
    
         
             
            from pertpy.tools._cinemaot import Cinemaot
         
     | 
| 
      
 3 
     | 
    
         
            +
            from pertpy.tools._coda._sccoda import Sccoda
         
     | 
| 
      
 4 
     | 
    
         
            +
            from pertpy.tools._coda._tasccoda import Tasccoda
         
     | 
| 
       5 
5 
     | 
    
         
             
            from pertpy.tools._dialogue import Dialogue
         
     | 
| 
       6 
6 
     | 
    
         
             
            from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
         
     | 
| 
       7 
7 
     | 
    
         
             
            from pertpy.tools._distances._distance_tests import DistanceTest
         
     | 
| 
       8 
8 
     | 
    
         
             
            from pertpy.tools._distances._distances import Distance
         
     | 
| 
       9 
     | 
    
         
            -
            from pertpy.tools. 
     | 
| 
      
 9 
     | 
    
         
            +
            from pertpy.tools._enrichment import Enrichment
         
     | 
| 
       10 
10 
     | 
    
         
             
            from pertpy.tools._milo import Milo
         
     | 
| 
       11 
11 
     | 
    
         
             
            from pertpy.tools._mixscape import Mixscape
         
     | 
| 
       12 
12 
     | 
    
         
             
            from pertpy.tools._perturbation_space._clustering import ClusteringSpace
         
     | 
| 
       13 
     | 
    
         
            -
            from pertpy.tools._perturbation_space. 
     | 
| 
      
 13 
     | 
    
         
            +
            from pertpy.tools._perturbation_space._discriminator_classifiers import (
         
     | 
| 
      
 14 
     | 
    
         
            +
                DiscriminatorClassifierSpace,
         
     | 
| 
      
 15 
     | 
    
         
            +
                LRClassifierSpace,
         
     | 
| 
      
 16 
     | 
    
         
            +
                MLPClassifierSpace,
         
     | 
| 
      
 17 
     | 
    
         
            +
            )
         
     | 
| 
       14 
18 
     | 
    
         
             
            from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
         
     | 
| 
       15 
19 
     | 
    
         
             
            from pertpy.tools._scgen import SCGEN
         
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
            try:
         
     | 
| 
       18 
     | 
    
         
            -
                from pertpy.tools._coda._sccoda import Sccoda
         
     | 
| 
       19 
     | 
    
         
            -
                from pertpy.tools._coda._tasccoda import Tasccoda
         
     | 
| 
       20 
     | 
    
         
            -
            except ImportError as e:
         
     | 
| 
       21 
     | 
    
         
            -
                if "ete3" in str(e):
         
     | 
| 
       22 
     | 
    
         
            -
                    print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3")
         
     | 
| 
       23 
     | 
    
         
            -
                else:
         
     | 
| 
       24 
     | 
    
         
            -
                    raise e
         
     | 
    
        pertpy/tools/_augur.py
    CHANGED
    
    | 
         @@ -4,8 +4,10 @@ import random 
     | 
|
| 
       4 
4 
     | 
    
         
             
            from collections import defaultdict
         
     | 
| 
       5 
5 
     | 
    
         
             
            from dataclasses import dataclass
         
     | 
| 
       6 
6 
     | 
    
         
             
            from math import floor, nan
         
     | 
| 
       7 
     | 
    
         
            -
            from typing import Any, Literal
         
     | 
| 
      
 7 
     | 
    
         
            +
            from typing import TYPE_CHECKING, Any, Literal
         
     | 
| 
       8 
8 
     | 
    
         | 
| 
      
 9 
     | 
    
         
            +
            import anndata as ad
         
     | 
| 
      
 10 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
       9 
11 
     | 
    
         
             
            import numpy as np
         
     | 
| 
       10 
12 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
       11 
13 
     | 
    
         
             
            import scanpy as sc
         
     | 
| 
         @@ -34,6 +36,10 @@ from sklearn.preprocessing import LabelEncoder 
     | 
|
| 
       34 
36 
     | 
    
         
             
            from skmisc.loess import loess
         
     | 
| 
       35 
37 
     | 
    
         
             
            from statsmodels.stats.multitest import fdrcorrection
         
     | 
| 
       36 
38 
     | 
    
         | 
| 
      
 39 
     | 
    
         
            +
            if TYPE_CHECKING:
         
     | 
| 
      
 40 
     | 
    
         
            +
                from matplotlib.axes import Axes
         
     | 
| 
      
 41 
     | 
    
         
            +
                from matplotlib.figure import Figure
         
     | 
| 
      
 42 
     | 
    
         
            +
             
     | 
| 
       37 
43 
     | 
    
         | 
| 
       38 
44 
     | 
    
         
             
            @dataclass
         
     | 
| 
       39 
45 
     | 
    
         
             
            class Params:
         
     | 
| 
         @@ -135,8 +141,8 @@ class Augur: 
     | 
|
| 
       135 
141 
     | 
    
         
             
                        # filter samples according to label
         
     | 
| 
       136 
142 
     | 
    
         
             
                        if condition_label is not None and treatment_label is not None:
         
     | 
| 
       137 
143 
     | 
    
         
             
                            print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
         
     | 
| 
       138 
     | 
    
         
            -
                            adata =  
     | 
| 
       139 
     | 
    
         
            -
                                adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
         
     | 
| 
      
 144 
     | 
    
         
            +
                            adata = ad.concat(
         
     | 
| 
      
 145 
     | 
    
         
            +
                                [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
         
     | 
| 
       140 
146 
     | 
    
         
             
                            )
         
     | 
| 
       141 
147 
     | 
    
         
             
                        label_encoder = LabelEncoder()
         
     | 
| 
       142 
148 
     | 
    
         
             
                        adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
         
     | 
| 
         @@ -214,7 +220,9 @@ class Augur: 
     | 
|
| 
       214 
220 
     | 
    
         
             
                        >>> loaded_data = ag_rfc.load(adata)
         
     | 
| 
       215 
221 
     | 
    
         
             
                        >>> ag_rfc.select_highly_variable(loaded_data)
         
     | 
| 
       216 
222 
     | 
    
         
             
                        >>> features = loaded_data.var_names
         
     | 
| 
       217 
     | 
    
         
            -
                        >>> subsample = ag_rfc.sample( 
     | 
| 
      
 223 
     | 
    
         
            +
                        >>> subsample = ag_rfc.sample(
         
     | 
| 
      
 224 
     | 
    
         
            +
                        ...     loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
         
     | 
| 
      
 225 
     | 
    
         
            +
                        ... )
         
     | 
| 
       218 
226 
     | 
    
         
             
                    """
         
     | 
| 
       219 
227 
     | 
    
         
             
                    # export subsampling.
         
     | 
| 
       220 
228 
     | 
    
         
             
                    random.seed(random_state)
         
     | 
| 
         @@ -230,7 +238,7 @@ class Augur: 
     | 
|
| 
       230 
238 
     | 
    
         
             
                                    random_state=random_state,
         
     | 
| 
       231 
239 
     | 
    
         
             
                                )
         
     | 
| 
       232 
240 
     | 
    
         
             
                            )
         
     | 
| 
       233 
     | 
    
         
            -
                        subsample =  
     | 
| 
      
 241 
     | 
    
         
            +
                        subsample = ad.concat([*label_subsamples], index_unique=None)
         
     | 
| 
       234 
242 
     | 
    
         
             
                    else:
         
     | 
| 
       235 
243 
     | 
    
         
             
                        subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
         
     | 
| 
       236 
244 
     | 
    
         | 
| 
         @@ -409,8 +417,8 @@ class Augur: 
     | 
|
| 
       409 
417 
     | 
    
         
             
                    """
         
     | 
| 
       410 
418 
     | 
    
         
             
                    if multiclass:
         
     | 
| 
       411 
419 
     | 
    
         
             
                        return {
         
     | 
| 
       412 
     | 
    
         
            -
                            "augur_score": make_scorer(roc_auc_score, multi_class="ovo",  
     | 
| 
       413 
     | 
    
         
            -
                            "auc": make_scorer(roc_auc_score, multi_class="ovo",  
     | 
| 
      
 420 
     | 
    
         
            +
                            "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
         
     | 
| 
      
 421 
     | 
    
         
            +
                            "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
         
     | 
| 
       414 
422 
     | 
    
         
             
                            "accuracy": make_scorer(accuracy_score),
         
     | 
| 
       415 
423 
     | 
    
         
             
                            "precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
         
     | 
| 
       416 
424 
     | 
    
         
             
                            "f1": make_scorer(f1_score, average="macro"),
         
     | 
| 
         @@ -418,8 +426,8 @@ class Augur: 
     | 
|
| 
       418 
426 
     | 
    
         
             
                        }
         
     | 
| 
       419 
427 
     | 
    
         
             
                    return (
         
     | 
| 
       420 
428 
     | 
    
         
             
                        {
         
     | 
| 
       421 
     | 
    
         
            -
                            "augur_score": make_scorer(roc_auc_score,  
     | 
| 
       422 
     | 
    
         
            -
                            "auc": make_scorer(roc_auc_score,  
     | 
| 
      
 429 
     | 
    
         
            +
                            "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
         
     | 
| 
      
 430 
     | 
    
         
            +
                            "auc": make_scorer(roc_auc_score, response_method="predict_proba"),
         
     | 
| 
       423 
431 
     | 
    
         
             
                            "accuracy": make_scorer(accuracy_score),
         
     | 
| 
       424 
432 
     | 
    
         
             
                            "precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
         
     | 
| 
       425 
433 
     | 
    
         
             
                            "f1": make_scorer(f1_score, average="binary"),
         
     | 
| 
         @@ -488,7 +496,7 @@ class Augur: 
     | 
|
| 
       488 
496 
     | 
    
         
             
                    # feature importances
         
     | 
| 
       489 
497 
     | 
    
         
             
                    feature_importances = defaultdict(list)
         
     | 
| 
       490 
498 
     | 
    
         
             
                    if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
         
     | 
| 
       491 
     | 
    
         
            -
                        for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
         
     | 
| 
      
 499 
     | 
    
         
            +
                        for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
         
     | 
| 
       492 
500 
     | 
    
         
             
                            feature_importances["genes"].extend(x.columns.tolist())
         
     | 
| 
       493 
501 
     | 
    
         
             
                            feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
         
     | 
| 
       494 
502 
     | 
    
         
             
                            feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
         
     | 
| 
         @@ -497,7 +505,7 @@ class Augur: 
     | 
|
| 
       497 
505 
     | 
    
         
             
                    # standardized coefficients with Agresti method
         
     | 
| 
       498 
506 
     | 
    
         
             
                    # cf. https://think-lab.github.io/d/205/#3
         
     | 
| 
       499 
507 
     | 
    
         
             
                    if isinstance(self.estimator, LogisticRegression):
         
     | 
| 
       500 
     | 
    
         
            -
                        for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
         
     | 
| 
      
 508 
     | 
    
         
            +
                        for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
         
     | 
| 
       501 
509 
     | 
    
         
             
                            feature_importances["genes"].extend(x.columns.tolist())
         
     | 
| 
       502 
510 
     | 
    
         
             
                            feature_importances["feature_importances"].extend(
         
     | 
| 
       503 
511 
     | 
    
         
             
                                (self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
         
     | 
| 
         @@ -723,6 +731,7 @@ class Augur: 
     | 
|
| 
       723 
731 
     | 
    
         
             
                        >>> loaded_data = ag_rfc.load(adata)
         
     | 
| 
       724 
732 
     | 
    
         
             
                        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
         
     | 
| 
       725 
733 
     | 
    
         
             
                    """
         
     | 
| 
      
 734 
     | 
    
         
            +
                    adata = adata.copy()
         
     | 
| 
       726 
735 
     | 
    
         
             
                    if augur_mode == "permute" and n_subsamples < 100:
         
     | 
| 
       727 
736 
     | 
    
         
             
                        n_subsamples = 500
         
     | 
| 
       728 
737 
     | 
    
         
             
                    if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
         
     | 
| 
         @@ -765,6 +774,7 @@ class Augur: 
     | 
|
| 
       765 
774 
     | 
    
         
             
                        elif (
         
     | 
| 
       766 
775 
     | 
    
         
             
                            cell_type_subsample.obs.groupby(
         
     | 
| 
       767 
776 
     | 
    
         
             
                                ["cell_type", "label"],
         
     | 
| 
      
 777 
     | 
    
         
            +
                                observed=True,
         
     | 
| 
       768 
778 
     | 
    
         
             
                            ).y_.count()
         
     | 
| 
       769 
779 
     | 
    
         
             
                            < subsample_size
         
     | 
| 
       770 
780 
     | 
    
         
             
                        ).any():
         
     | 
| 
         @@ -804,7 +814,7 @@ class Augur: 
     | 
|
| 
       804 
814 
     | 
    
         
             
                                * (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
         
     | 
| 
       805 
815 
     | 
    
         
             
                            )
         
     | 
| 
       806 
816 
     | 
    
         | 
| 
       807 
     | 
    
         
            -
                            for idx, cv in zip(range(n_subsamples), results[cell_type]):
         
     | 
| 
      
 817 
     | 
    
         
            +
                            for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
         
     | 
| 
       808 
818 
     | 
    
         
             
                                results["full_results"]["idx"].extend([idx] * folds)
         
     | 
| 
       809 
819 
     | 
    
         
             
                                results["full_results"]["augur_score"].extend(cv["test_augur_score"])
         
     | 
| 
       810 
820 
     | 
    
         
             
                                results["full_results"]["folds"].extend(range(folds))
         
     | 
| 
         @@ -869,28 +879,31 @@ class Augur: 
     | 
|
| 
       869 
879 
     | 
    
         
             
                        & set(permuted_results1["summary_metrics"].columns)
         
     | 
| 
       870 
880 
     | 
    
         
             
                        & set(permuted_results2["summary_metrics"].columns)
         
     | 
| 
       871 
881 
     | 
    
         
             
                    )
         
     | 
| 
      
 882 
     | 
    
         
            +
             
     | 
| 
      
 883 
     | 
    
         
            +
                    cell_types_list = list(cell_types)
         
     | 
| 
      
 884 
     | 
    
         
            +
             
     | 
| 
       872 
885 
     | 
    
         
             
                    # mean augur scores
         
     | 
| 
       873 
886 
     | 
    
         
             
                    augur_score1 = (
         
     | 
| 
       874 
887 
     | 
    
         
             
                        augur_results1["summary_metrics"]
         
     | 
| 
       875 
     | 
    
         
            -
                        .loc["mean_augur_score",  
     | 
| 
      
 888 
     | 
    
         
            +
                        .loc["mean_augur_score", cell_types_list]
         
     | 
| 
       876 
889 
     | 
    
         
             
                        .reset_index()
         
     | 
| 
       877 
890 
     | 
    
         
             
                        .rename(columns={"index": "cell_type"})
         
     | 
| 
       878 
891 
     | 
    
         
             
                    )
         
     | 
| 
       879 
892 
     | 
    
         
             
                    augur_score2 = (
         
     | 
| 
       880 
893 
     | 
    
         
             
                        augur_results2["summary_metrics"]
         
     | 
| 
       881 
     | 
    
         
            -
                        .loc["mean_augur_score",  
     | 
| 
      
 894 
     | 
    
         
            +
                        .loc["mean_augur_score", cell_types_list]
         
     | 
| 
       882 
895 
     | 
    
         
             
                        .reset_index()
         
     | 
| 
       883 
896 
     | 
    
         
             
                        .rename(columns={"index": "cell_type"})
         
     | 
| 
       884 
897 
     | 
    
         
             
                    )
         
     | 
| 
       885 
898 
     | 
    
         | 
| 
       886 
899 
     | 
    
         
             
                    # mean permuted scores over cross validation runs
         
     | 
| 
       887 
900 
     | 
    
         
             
                    permuted_cv_augur1 = (
         
     | 
| 
       888 
     | 
    
         
            -
                        permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin( 
     | 
| 
      
 901 
     | 
    
         
            +
                        permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
         
     | 
| 
       889 
902 
     | 
    
         
             
                        .groupby(["cell_type", "idx"], as_index=False)
         
     | 
| 
       890 
903 
     | 
    
         
             
                        .mean()
         
     | 
| 
       891 
904 
     | 
    
         
             
                    )
         
     | 
| 
       892 
905 
     | 
    
         
             
                    permuted_cv_augur2 = (
         
     | 
| 
       893 
     | 
    
         
            -
                        permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin( 
     | 
| 
      
 906 
     | 
    
         
            +
                        permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
         
     | 
| 
       894 
907 
     | 
    
         
             
                        .groupby(["cell_type", "idx"], as_index=False)
         
     | 
| 
       895 
908 
     | 
    
         
             
                        .mean()
         
     | 
| 
       896 
909 
     | 
    
         
             
                    )
         
     | 
| 
         @@ -901,7 +914,7 @@ class Augur: 
     | 
|
| 
       901 
914 
     | 
    
         
             
                    # draw mean aucs for permute1 and permute2
         
     | 
| 
       902 
915 
     | 
    
         
             
                    for celltype in permuted_cv_augur1["cell_type"].unique():
         
     | 
| 
       903 
916 
     | 
    
         
             
                        df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
         
     | 
| 
       904 
     | 
    
         
            -
                        df2 = permuted_cv_augur2[ 
     | 
| 
      
 917 
     | 
    
         
            +
                        df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
         
     | 
| 
       905 
918 
     | 
    
         
             
                        for permutation_idx in range(n_permutations):
         
     | 
| 
       906 
919 
     | 
    
         
             
                            # subsample
         
     | 
| 
       907 
920 
     | 
    
         
             
                            sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
         
     | 
| 
         @@ -961,3 +974,288 @@ class Augur: 
     | 
|
| 
       961 
974 
     | 
    
         
             
                    delta["padj"] = fdrcorrection(delta["pval"])[1]
         
     | 
| 
       962 
975 
     | 
    
         | 
| 
       963 
976 
     | 
    
         
             
                    return delta
         
     | 
| 
      
 977 
     | 
    
         
            +
             
     | 
| 
      
 978 
     | 
    
         
            +
                def plot_dp_scatter(
                    self,
                    results: pd.DataFrame,
                    top_n: int = None,
                    return_fig: bool | None = None,
                    ax: Axes = None,
                    show: bool | None = None,
                    save: str | bool | None = None,
                ) -> Axes | Figure | None:
                    """Plot scatterplot of differential prioritization.

                    Each row of `results` is drawn as one point at
                    (`mean_augur_score1`, `mean_augur_score2`), coloured by its `z` value.
                    A dashed diagonal is added as a visual reference for equal scores.

                    Args:
                        results: Results after running differential prioritization. The columns
                            `mean_augur_score1`, `mean_augur_score2`, `z`, `pval` and `cell_type` are read.
                        top_n: optionally, the number of top prioritized cell types (smallest `pval`)
                            to label in the plot. With `None` every row is labelled.
                        return_fig: if truthy, the figure that owns `ax` is returned.
                        ax: optionally, axes used to draw plot. A new figure and axes are created if `None`.
                        show: if truthy, `plt.show()` is called.
                        save: if truthy, passed as the target to `savefig` (with `bbox_inches="tight"`).

                    Returns:
                        The figure if `return_fig` is truthy; otherwise the axes of the plot if neither
                        `show` nor `save` is truthy; otherwise `None`.

                    Examples:
                        >>> import pertpy as pt
                        >>> adata = pt.dt.bhattacherjee()
                        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

                        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
                        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
                        >>> adata_15_permute, results_15_permute = ag_rfc.predict(data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)

                        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
                        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
                        >>> adata_48_permute, results_48_permute = ag_rfc.predict(data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)

                        >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
                            permuted_results1=results_15_permute, permuted_results2=results_48_permute)
                        >>> ag_rfc.plot_dp_scatter(pvals)

                    Preview:
                        .. image:: /_static/docstring_previews/augur_dp_scatter.png
                    """
                    x = results["mean_augur_score1"]
                    y = results["mean_augur_score2"]

                    if ax is None:
                        _, ax = plt.subplots()
                    # Work on the figure that owns `ax`, so a caller-supplied axes is honoured
                    # even when it is not pyplot's current axes/figure.
                    fig = ax.figure
                    scatter = ax.scatter(x, y, c=results.z, cmap="Greens")

                    # adding optional labels (slicing with top_n=None keeps every row)
                    top_n_index = results.sort_values(by="pval").index[:top_n]
                    for idx in top_n_index:
                        ax.annotate(
                            results.loc[idx, "cell_type"],
                            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
                        )

                    # add diagonal spanning the combined range of both axes; comparing the
                    # two limit tuples directly would only pick one axis' limits.
                    axis_limits = (*ax.get_xlim(), *ax.get_ylim())
                    limits = (min(axis_limits), max(axis_limits))
                    ax.plot(limits, limits, ls="--", c=".3")

                    # formatting and details
                    ax.set_xlabel("Augur scores 1")
                    ax.set_ylabel("Augur scores 2")
                    legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
                    ax.add_artist(legend1)

                    if save:
                        fig.savefig(save, bbox_inches="tight")
                    if show:
                        plt.show()
                    if return_fig:
                        return fig
                    if not (show or save):
                        return ax
                    return None
         
     | 
| 
      
 1051 
     | 
    
         
            +
             
     | 
| 
      
 1052 
     | 
    
         
            +
                def plot_important_features(
                    self,
                    data: dict[str, Any],
                    key: str = "augurpy_results",
                    top_n: int = 10,
                    return_fig: bool | None = None,
                    ax: Axes = None,
                    show: bool | None = None,
                    save: str | bool | None = None,
                ) -> Axes | Figure | None:
                    """Plot a lollipop plot of the n features with largest feature importances.

                    Feature importances are averaged per gene, and the `top_n` genes with the
                    largest mean are drawn in ascending order from bottom to top.

                    Args:
                        data: results after running `predict()` as dictionary or the AnnData object.
                            For an AnnData object the results are read from `data.uns[key]`.
                        key: Key in the AnnData object of the results
                        top_n: n number feature importance values to plot. Default is 10.
                            If fewer genes are available, all of them are plotted.
                        return_fig: if truthy, the figure that owns `ax` is returned.
                        ax: optionally, axes used to draw plot. A new figure and axes are created if `None`.
                        show: if truthy, `plt.show()` is called.
                        save: if truthy, passed as the target to `savefig` (with `bbox_inches="tight"`).

                    Returns:
                        The figure if `return_fig` is truthy; otherwise the axes of the plot if neither
                        `show` nor `save` is truthy; otherwise `None`.

                    Examples:
                        >>> import pertpy as pt
                        >>> adata = pt.dt.sc_sim_augur()
                        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
                        >>> loaded_data = ag_rfc.load(adata)
                        >>> v_adata, v_results = ag_rfc.predict(
                        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
                        ... )
                        >>> ag_rfc.plot_important_features(v_results)

                    Preview:
                        .. image:: /_static/docstring_previews/augur_important_features.png
                    """
                    if isinstance(data, AnnData):
                        results = data.uns[key]
                    else:
                        results = data
                    n_features = (
                        results["feature_importances"]
                        .groupby("genes", as_index=False)
                        .feature_importances.mean()
                        .sort_values(by="feature_importances")[-top_n:]
                    )

                    if ax is None:
                        _, ax = plt.subplots()
                    # Work on the figure that owns `ax`, so a caller-supplied axes is honoured
                    # even when it is not pyplot's current axes/figure.
                    fig = ax.figure
                    # One y position per selected gene: the selection can hold fewer than
                    # `top_n` rows, and the y positions must match the plotted values in length.
                    y_axes_range = range(1, len(n_features) + 1)
                    ax.hlines(
                        y_axes_range,
                        xmin=0,
                        xmax=n_features["feature_importances"],
                    )

                    ax.plot(n_features["feature_importances"], y_axes_range, "o")

                    ax.set_xlabel("Mean Feature Importance")
                    ax.set_ylabel("Gene")
                    ax.set_yticks(list(y_axes_range))
                    ax.set_yticklabels(n_features["genes"])

                    if save:
                        fig.savefig(save, bbox_inches="tight")
                    if show:
                        plt.show()
                    if return_fig:
                        return fig
                    if not (show or save):
                        return ax
                    return None
         
     | 
| 
      
 1122 
     | 
    
         
            +
             
     | 
| 
      
 1123 
     | 
    
         
            +
                def plot_lollipop(
         
     | 
| 
      
 1124 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 1125 
     | 
    
         
            +
                    data: dict[str, Any],
         
     | 
| 
      
 1126 
     | 
    
         
            +
                    key: str = "augurpy_results",
         
     | 
| 
      
 1127 
     | 
    
         
            +
                    return_fig: bool | None = None,
         
     | 
| 
      
 1128 
     | 
    
         
            +
                    ax: Axes = None,
         
     | 
| 
      
 1129 
     | 
    
         
            +
                    show: bool | None = None,
         
     | 
| 
      
 1130 
     | 
    
         
            +
                    save: str | bool | None = None,
         
     | 
| 
      
 1131 
     | 
    
         
            +
                ) -> Axes | Figure | None:
         
     | 
| 
      
 1132 
     | 
    
         
            +
                    """Plot a lollipop plot of the mean augur values.
         
     | 
| 
      
 1133 
     | 
    
         
            +
             
     | 
| 
      
 1134 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 1135 
     | 
    
         
            +
                        results: results after running `predict()` as dictionary or the AnnData object.
         
     | 
| 
      
 1136 
     | 
    
         
            +
                        key: Key in the AnnData object of the results
         
     | 
| 
      
 1137 
     | 
    
         
            +
                        ax: optionally, axes used to draw plot
         
     | 
| 
      
 1138 
     | 
    
         
            +
                        return_figure: if `True` returns figure of the plot
         
     | 
| 
      
 1139 
     | 
    
         
            +
             
     | 
| 
      
 1140 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 1141 
     | 
    
         
            +
                        Axes of the plot.
         
     | 
| 
      
 1142 
     | 
    
         
            +
             
     | 
| 
      
 1143 
     | 
    
         
            +
                    Examples:
         
     | 
| 
      
 1144 
     | 
    
         
            +
                        >>> import pertpy as pt
         
     | 
| 
      
 1145 
     | 
    
         
            +
                        >>> adata = pt.dt.sc_sim_augur()
         
     | 
| 
      
 1146 
     | 
    
         
            +
                        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
         
     | 
| 
      
 1147 
     | 
    
         
            +
                        >>> loaded_data = ag_rfc.load(adata)
         
     | 
| 
      
 1148 
     | 
    
         
            +
                        >>> v_adata, v_results = ag_rfc.predict(
         
     | 
| 
      
 1149 
     | 
    
         
            +
                        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
         
     | 
| 
      
 1150 
     | 
    
         
            +
                        ... )
         
     | 
| 
      
 1151 
     | 
    
         
            +
                        >>> ag_rfc.plot_lollipop(v_results)
         
     | 
| 
      
 1152 
     | 
    
         
            +
             
     | 
| 
      
 1153 
     | 
    
         
            +
                    Preview:
         
     | 
| 
      
 1154 
     | 
    
         
            +
                        .. image:: /_static/docstring_previews/augur_lollipop.png
         
     | 
| 
      
 1155 
     | 
    
         
            +
                    """
         
     | 
| 
      
 1156 
     | 
    
         
            +
                    if isinstance(data, AnnData):
         
     | 
| 
      
 1157 
     | 
    
         
            +
                        results = data.uns[key]
         
     | 
| 
      
 1158 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 1159 
     | 
    
         
            +
                        results = data
         
     | 
| 
      
 1160 
     | 
    
         
            +
                    if ax is None:
         
     | 
| 
      
 1161 
     | 
    
         
            +
                        fig, ax = plt.subplots()
         
     | 
| 
      
 1162 
     | 
    
         
            +
                    y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
         
     | 
| 
      
 1163 
     | 
    
         
            +
                    ax.hlines(
         
     | 
| 
      
 1164 
     | 
    
         
            +
                        y_axes_range,
         
     | 
| 
      
 1165 
     | 
    
         
            +
                        xmin=0,
         
     | 
| 
      
 1166 
     | 
    
         
            +
                        xmax=results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
         
     | 
| 
      
 1167 
     | 
    
         
            +
                    )
         
     | 
| 
      
 1168 
     | 
    
         
            +
             
     | 
| 
      
 1169 
     | 
    
         
            +
                    ax.plot(
         
     | 
| 
      
 1170 
     | 
    
         
            +
                        results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"],
         
     | 
| 
      
 1171 
     | 
    
         
            +
                        y_axes_range,
         
     | 
| 
      
 1172 
     | 
    
         
            +
                        "o",
         
     | 
| 
      
 1173 
     | 
    
         
            +
                    )
         
     | 
| 
      
 1174 
     | 
    
         
            +
             
     | 
| 
      
 1175 
     | 
    
         
            +
                    plt.xlabel("Mean Augur Score")
         
     | 
| 
      
 1176 
     | 
    
         
            +
                    plt.ylabel("Cell Type")
         
     | 
| 
      
 1177 
     | 
    
         
            +
                    plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
         
     | 
| 
      
 1178 
     | 
    
         
            +
             
     | 
| 
      
 1179 
     | 
    
         
            +
                    if save:
         
     | 
| 
      
 1180 
     | 
    
         
            +
                        plt.savefig(save, bbox_inches="tight")
         
     | 
| 
      
 1181 
     | 
    
         
            +
                    if show:
         
     | 
| 
      
 1182 
     | 
    
         
            +
                        plt.show()
         
     | 
| 
      
 1183 
     | 
    
         
            +
                    if return_fig:
         
     | 
| 
      
 1184 
     | 
    
         
            +
                        return plt.gcf()
         
     | 
| 
      
 1185 
     | 
    
         
            +
                    if not (show or save):
         
     | 
| 
      
 1186 
     | 
    
         
            +
                        return ax
         
     | 
| 
      
 1187 
     | 
    
         
            +
                    return None
         
     | 
| 
      
 1188 
     | 
    
         
            +
             
     | 
| 
      
 1189 
     | 
    
         
            +
                def plot_scatterplot(
                    self,
                    results1: dict[str, Any],
                    results2: dict[str, Any],
                    top_n: int | None = None,
                    return_fig: bool | None = None,
                    show: bool | None = None,
                    save: str | bool | None = None,
                ) -> Axes | Figure | None:
                    """Create scatterplot with two augur results.

                    Every cell type found in the columns of ``results1["summary_metrics"]`` is drawn as one point:
                    its `mean_augur_score` from `results1` on the x-axis against the one from `results2` on the y-axis.
                    A dashed diagonal is added as a visual reference for equal scores.

                    Args:
                        results1: results after running `predict()`; must contain a `summary_metrics` table
                            with a `mean_augur_score` row.
                        results2: results after running `predict()`; must provide a `mean_augur_score`
                            for every cell type of `results1`.
                        top_n: optionally, the number of top prioritized cell types to label in the plot.
                            Cell types are ranked by the score difference `results1 - results2` (descending).
                            If `None`, all cell types are labelled.
                        return_fig: if `True` returns figure of the plot
                        show: if truthy, display the plot with `plt.show()`.
                        save: if truthy, forwarded to `savefig` as the output target.
                            NOTE(review): a bare `True` is forwarded unchanged - confirm that is intended.

                    Returns:
                        The figure if `return_fig` is truthy. Otherwise the axes of the plot when neither
                        `show` nor `save` is set, and `None` in all remaining cases.

                    Examples:
                        >>> import pertpy as pt
                        >>> adata = pt.dt.sc_sim_augur()
                        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
                        >>> loaded_data = ag_rfc.load(adata)
                        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
                        >>> v_adata, v_results = ag_rfc.predict(
                        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
                        ... )
                        >>> ag_rfc.plot_scatterplot(v_results, h_results)

                    Preview:
                        .. image:: /_static/docstring_previews/augur_scatterplot.png
                    """
                    summary1 = results1["summary_metrics"]
                    summary2 = results2["summary_metrics"]
                    cell_types = summary1.columns

                    fig, ax = plt.subplots()
                    ax.scatter(
                        summary1.loc["mean_augur_score", cell_types],
                        summary2.loc["mean_augur_score", cell_types],
                    )

                    # adding optional labels; slicing with `top_n=None` keeps every cell type
                    top_n_cell_types = (
                        (summary1.loc["mean_augur_score"] - summary2.loc["mean_augur_score"])
                        .sort_values(ascending=False)
                        .index[:top_n]
                    )
                    for txt in top_n_cell_types:
                        ax.annotate(
                            txt,
                            (
                                summary1.loc["mean_augur_score", txt],
                                summary2.loc["mean_augur_score", txt],
                            ),
                        )

                    # adding diagonal
                    # Span the union of both axis ranges. Taking `max()` of the two limit tuples would compare
                    # them lexicographically and pick only one axis' limits, so the line could stop short.
                    x_low, x_high = ax.get_xlim()
                    y_low, y_high = ax.get_ylim()
                    limits = (min(x_low, y_low), max(x_high, y_high))
                    ax.plot(limits, limits, ls="--", c=".3")

                    ax.set_xlabel("Augur scores 1")
                    ax.set_ylabel("Augur scores 2")

                    if save:
                        fig.savefig(save, bbox_inches="tight")
                    if show:
                        plt.show()
                    if return_fig:
                        # Return the figure created above instead of `plt.gcf()`: after `plt.show()` the
                        # "current" figure is not guaranteed to still be this one.
                        return fig
                    if not (show or save):
                        return ax
                    return None
         
     |