pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
 - pertpy/data/__init__.py +66 -1
 - pertpy/data/_dataloader.py +28 -26
 - pertpy/data/_datasets.py +261 -92
 - pertpy/metadata/__init__.py +6 -0
 - pertpy/metadata/_cell_line.py +795 -0
 - pertpy/metadata/_compound.py +128 -0
 - pertpy/metadata/_drug.py +238 -0
 - pertpy/metadata/_look_up.py +569 -0
 - pertpy/metadata/_metadata.py +70 -0
 - pertpy/metadata/_moa.py +125 -0
 - pertpy/plot/__init__.py +0 -13
 - pertpy/preprocessing/__init__.py +2 -0
 - pertpy/preprocessing/_guide_rna.py +89 -6
 - pertpy/tools/__init__.py +48 -15
 - pertpy/tools/_augur.py +329 -32
 - pertpy/tools/_cinemaot.py +145 -6
 - pertpy/tools/_coda/_base_coda.py +1237 -116
 - pertpy/tools/_coda/_sccoda.py +66 -36
 - pertpy/tools/_coda/_tasccoda.py +46 -39
 - pertpy/tools/_dialogue.py +180 -77
 - pertpy/tools/_differential_gene_expression/__init__.py +20 -0
 - pertpy/tools/_differential_gene_expression/_base.py +657 -0
 - pertpy/tools/_differential_gene_expression/_checks.py +41 -0
 - pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
 - pertpy/tools/_differential_gene_expression/_edger.py +125 -0
 - pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
 - pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
 - pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
 - pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
 - pertpy/tools/_distances/_distance_tests.py +29 -24
 - pertpy/tools/_distances/_distances.py +584 -98
 - pertpy/tools/_enrichment.py +460 -0
 - pertpy/tools/_kernel_pca.py +1 -1
 - pertpy/tools/_milo.py +406 -49
 - pertpy/tools/_mixscape.py +677 -55
 - pertpy/tools/_perturbation_space/_clustering.py +10 -3
 - pertpy/tools/_perturbation_space/_comparison.py +112 -0
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
 - pertpy/tools/_perturbation_space/_simple.py +52 -11
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_base_components.py +2 -3
 - pertpy/tools/_scgen/_scgen.py +706 -0
 - pertpy/tools/_scgen/_utils.py +3 -5
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
 - pertpy-0.8.0.dist-info/RECORD +57 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_augur.py +0 -234
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_coda.py +0 -1001
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_guide_rna.py +0 -82
 - pertpy/plot/_milopy.py +0 -284
 - pertpy/plot/_mixscape.py +0 -594
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_differential_gene_expression.py +0 -99
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
 
    
        pertpy/plot/_scgen.py
    DELETED
    
    | 
         @@ -1,337 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       2 
     | 
    
         
            -
            import pandas as pd
         
     | 
| 
       3 
     | 
    
         
            -
            import scanpy as sc
         
     | 
| 
       4 
     | 
    
         
            -
            from adjustText import adjust_text
         
     | 
| 
       5 
     | 
    
         
            -
            from matplotlib import pyplot
         
     | 
| 
       6 
     | 
    
         
            -
            from scipy import stats
         
     | 
| 
       7 
     | 
    
         
            -
            from scvi import REGISTRY_KEYS
         
     | 
| 
       8 
     | 
    
         
            -
             
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
       10 
     | 
    
         
            -
            class JaxscgenPlot:
         
     | 
| 
       11 
     | 
    
         
            -
                """Plotting functions for Jaxscgen."""
         
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       14 
     | 
    
         
            -
                def reg_mean_plot(
         
     | 
| 
       15 
     | 
    
         
            -
                    adata,
         
     | 
| 
       16 
     | 
    
         
            -
                    condition_key,
         
     | 
| 
       17 
     | 
    
         
            -
                    axis_keys,
         
     | 
| 
       18 
     | 
    
         
            -
                    labels,
         
     | 
| 
       19 
     | 
    
         
            -
                    path_to_save="./reg_mean.pdf",
         
     | 
| 
       20 
     | 
    
         
            -
                    save=True,
         
     | 
| 
       21 
     | 
    
         
            -
                    gene_list=None,
         
     | 
| 
       22 
     | 
    
         
            -
                    show=False,
         
     | 
| 
       23 
     | 
    
         
            -
                    top_100_genes=None,
         
     | 
| 
       24 
     | 
    
         
            -
                    verbose=False,
         
     | 
| 
       25 
     | 
    
         
            -
                    legend=True,
         
     | 
| 
       26 
     | 
    
         
            -
                    title=None,
         
     | 
| 
       27 
     | 
    
         
            -
                    x_coeff=0.30,
         
     | 
| 
       28 
     | 
    
         
            -
                    y_coeff=0.8,
         
     | 
| 
       29 
     | 
    
         
            -
                    fontsize=14,
         
     | 
| 
       30 
     | 
    
         
            -
                    **kwargs,
         
     | 
| 
       31 
     | 
    
         
            -
                ):
         
     | 
| 
       32 
     | 
    
         
            -
                    """Plots mean matching figure for a set of specific genes.
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
                    Args:
         
     | 
| 
       35 
     | 
    
         
            -
                        adata:  AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
         
     | 
| 
       36 
     | 
    
         
            -
                                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
         
     | 
| 
       37 
     | 
    
         
            -
                                corresponding to batch and cell type metadata, respectively.
         
     | 
| 
       38 
     | 
    
         
            -
                        condition_key: The key for the condition
         
     | 
| 
       39 
     | 
    
         
            -
                        axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
         
     | 
| 
       40 
     | 
    
         
            -
                                   `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
         
     | 
| 
       41 
     | 
    
         
            -
                        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
         
     | 
| 
       42 
     | 
    
         
            -
                        path_to_save: path to save the plot.
         
     | 
| 
       43 
     | 
    
         
            -
                        save: Specify if the plot should be saved or not.
         
     | 
| 
       44 
     | 
    
         
            -
                        gene_list: list of gene names to be plotted.
         
     | 
| 
       45 
     | 
    
         
            -
                        show: if `True`: will show to the plot after saving it.
         
     | 
| 
       46 
     | 
    
         
            -
                        top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
         
     | 
| 
       47 
     | 
    
         
            -
                        verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
         
     | 
| 
       48 
     | 
    
         
            -
                        legend: if `True`: plots a legend, defaults to `True`.
         
     | 
| 
       49 
     | 
    
         
            -
                        title: Set if you want the plot to display a title.
         
     | 
| 
       50 
     | 
    
         
            -
                        x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
         
     | 
| 
       51 
     | 
    
         
            -
                        y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
         
     | 
| 
       52 
     | 
    
         
            -
                        fontsize: Fontsize used for text in the plot, defaults to 14.
         
     | 
| 
       53 
     | 
    
         
            -
                        **kwargs:
         
     | 
| 
       54 
     | 
    
         
            -
             
     | 
| 
       55 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       56 
     | 
    
         
            -
                        >>> import pertpy at pt
         
     | 
| 
       57 
     | 
    
         
            -
                        >>> data = pt.dt.kang_2018()
         
     | 
| 
       58 
     | 
    
         
            -
                        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
         
     | 
| 
       59 
     | 
    
         
            -
                        >>> model = pt.tl.SCGEN(data)
         
     | 
| 
       60 
     | 
    
         
            -
                        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
         
     | 
| 
       61 
     | 
    
         
            -
                        >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
         
     | 
| 
       62 
     | 
    
         
            -
                        >>> pred.obs['label'] = 'pred'
         
     | 
| 
       63 
     | 
    
         
            -
                        >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
         
     | 
| 
       64 
     | 
    
         
            -
                        >>> r2_value = pt.pl.scg.reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \
         
     | 
| 
       65 
     | 
    
         
            -
                            labels={"x": "predicted", "y": "ground truth"}, save=False, show=True)
         
     | 
| 
       66 
     | 
    
         
            -
                    """
         
     | 
| 
       67 
     | 
    
         
            -
                    import seaborn as sns
         
     | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
       69 
     | 
    
         
            -
                    sns.set()
         
     | 
| 
       70 
     | 
    
         
            -
                    sns.set(color_codes=True)
         
     | 
| 
       71 
     | 
    
         
            -
             
     | 
| 
       72 
     | 
    
         
            -
                    diff_genes = top_100_genes
         
     | 
| 
       73 
     | 
    
         
            -
                    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
         
     | 
| 
       74 
     | 
    
         
            -
                    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
         
     | 
| 
       75 
     | 
    
         
            -
                    if diff_genes is not None:
         
     | 
| 
       76 
     | 
    
         
            -
                        if hasattr(diff_genes, "tolist"):
         
     | 
| 
       77 
     | 
    
         
            -
                            diff_genes = diff_genes.tolist()
         
     | 
| 
       78 
     | 
    
         
            -
                        adata_diff = adata[:, diff_genes]
         
     | 
| 
       79 
     | 
    
         
            -
                        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
         
     | 
| 
       80 
     | 
    
         
            -
                        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
         
     | 
| 
       81 
     | 
    
         
            -
                        x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
         
     | 
| 
       82 
     | 
    
         
            -
                        y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
         
     | 
| 
       83 
     | 
    
         
            -
                        m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
         
     | 
| 
       84 
     | 
    
         
            -
                        if verbose:
         
     | 
| 
       85 
     | 
    
         
            -
                            print("top_100 DEGs mean: ", r_value_diff**2)
         
     | 
| 
       86 
     | 
    
         
            -
                    x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
         
     | 
| 
       87 
     | 
    
         
            -
                    y = np.asarray(np.mean(stim.X, axis=0)).ravel()
         
     | 
| 
       88 
     | 
    
         
            -
                    m, b, r_value, p_value, std_err = stats.linregress(x, y)
         
     | 
| 
       89 
     | 
    
         
            -
                    if verbose:
         
     | 
| 
       90 
     | 
    
         
            -
                        print("All genes mean: ", r_value**2)
         
     | 
| 
       91 
     | 
    
         
            -
                    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
         
     | 
| 
       92 
     | 
    
         
            -
                    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
         
     | 
| 
       93 
     | 
    
         
            -
                    ax.tick_params(labelsize=fontsize)
         
     | 
| 
       94 
     | 
    
         
            -
                    if "range" in kwargs:
         
     | 
| 
       95 
     | 
    
         
            -
                        start, stop, step = kwargs.get("range")
         
     | 
| 
       96 
     | 
    
         
            -
                        ax.set_xticks(np.arange(start, stop, step))
         
     | 
| 
       97 
     | 
    
         
            -
                        ax.set_yticks(np.arange(start, stop, step))
         
     | 
| 
       98 
     | 
    
         
            -
                    ax.set_xlabel(labels["x"], fontsize=fontsize)
         
     | 
| 
       99 
     | 
    
         
            -
                    ax.set_ylabel(labels["y"], fontsize=fontsize)
         
     | 
| 
       100 
     | 
    
         
            -
                    if gene_list is not None:
         
     | 
| 
       101 
     | 
    
         
            -
                        texts = []
         
     | 
| 
       102 
     | 
    
         
            -
                        for i in gene_list:
         
     | 
| 
       103 
     | 
    
         
            -
                            j = adata.var_names.tolist().index(i)
         
     | 
| 
       104 
     | 
    
         
            -
                            x_bar = x[j]
         
     | 
| 
       105 
     | 
    
         
            -
                            y_bar = y[j]
         
     | 
| 
       106 
     | 
    
         
            -
                            texts.append(pyplot.text(x_bar, y_bar, i, fontsize=11, color="black"))
         
     | 
| 
       107 
     | 
    
         
            -
                            pyplot.plot(x_bar, y_bar, "o", color="red", markersize=5)
         
     | 
| 
       108 
     | 
    
         
            -
                            # if "y1" in axis_keys.keys():
         
     | 
| 
       109 
     | 
    
         
            -
                            # y1_bar = y1[j]
         
     | 
| 
       110 
     | 
    
         
            -
                            # pyplot.text(x_bar, y1_bar, i, fontsize=11, color="black")
         
     | 
| 
       111 
     | 
    
         
            -
                    if gene_list is not None:
         
     | 
| 
       112 
     | 
    
         
            -
                        adjust_text(
         
     | 
| 
       113 
     | 
    
         
            -
                            texts,
         
     | 
| 
       114 
     | 
    
         
            -
                            x=x,
         
     | 
| 
       115 
     | 
    
         
            -
                            y=y,
         
     | 
| 
       116 
     | 
    
         
            -
                            arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
         
     | 
| 
       117 
     | 
    
         
            -
                            force_points=(0.0, 0.0),
         
     | 
| 
       118 
     | 
    
         
            -
                        )
         
     | 
| 
       119 
     | 
    
         
            -
                    if legend:
         
     | 
| 
       120 
     | 
    
         
            -
                        pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5))
         
     | 
| 
       121 
     | 
    
         
            -
                    if title is None:
         
     | 
| 
       122 
     | 
    
         
            -
                        pyplot.title("", fontsize=fontsize)
         
     | 
| 
       123 
     | 
    
         
            -
                    else:
         
     | 
| 
       124 
     | 
    
         
            -
                        pyplot.title(title, fontsize=fontsize)
         
     | 
| 
       125 
     | 
    
         
            -
                    ax.text(
         
     | 
| 
       126 
     | 
    
         
            -
                        max(x) - max(x) * x_coeff,
         
     | 
| 
       127 
     | 
    
         
            -
                        max(y) - y_coeff * max(y),
         
     | 
| 
       128 
     | 
    
         
            -
                        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
         
     | 
| 
       129 
     | 
    
         
            -
                        fontsize=kwargs.get("textsize", fontsize),
         
     | 
| 
       130 
     | 
    
         
            -
                    )
         
     | 
| 
       131 
     | 
    
         
            -
                    if diff_genes is not None:
         
     | 
| 
       132 
     | 
    
         
            -
                        ax.text(
         
     | 
| 
       133 
     | 
    
         
            -
                            max(x) - max(x) * x_coeff,
         
     | 
| 
       134 
     | 
    
         
            -
                            max(y) - (y_coeff + 0.15) * max(y),
         
     | 
| 
       135 
     | 
    
         
            -
                            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
         
     | 
| 
       136 
     | 
    
         
            -
                            fontsize=kwargs.get("textsize", fontsize),
         
     | 
| 
       137 
     | 
    
         
            -
                        )
         
     | 
| 
       138 
     | 
    
         
            -
                    if save:
         
     | 
| 
       139 
     | 
    
         
            -
                        pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
         
     | 
| 
       140 
     | 
    
         
            -
                    if show:
         
     | 
| 
       141 
     | 
    
         
            -
                        pyplot.show()
         
     | 
| 
       142 
     | 
    
         
            -
                    pyplot.close()
         
     | 
| 
       143 
     | 
    
         
            -
                    if diff_genes is not None:
         
     | 
| 
       144 
     | 
    
         
            -
                        return r_value**2, r_value_diff**2
         
     | 
| 
       145 
     | 
    
         
            -
                    else:
         
     | 
| 
       146 
     | 
    
         
            -
                        return r_value**2
         
     | 
| 
       147 
     | 
    
         
            -
             
     | 
| 
       148 
     | 
    
         
            -
                @staticmethod
         
     | 
| 
       149 
     | 
    
         
            -
                def reg_var_plot(
         
     | 
| 
       150 
     | 
    
         
            -
                    adata,
         
     | 
| 
       151 
     | 
    
         
            -
                    condition_key,
         
     | 
| 
       152 
     | 
    
         
            -
                    axis_keys,
         
     | 
| 
       153 
     | 
    
         
            -
                    labels,
         
     | 
| 
       154 
     | 
    
         
            -
                    path_to_save="./reg_var.pdf",
         
     | 
| 
       155 
     | 
    
         
            -
                    save=True,
         
     | 
| 
       156 
     | 
    
         
            -
                    gene_list=None,
         
     | 
| 
       157 
     | 
    
         
            -
                    top_100_genes=None,
         
     | 
| 
       158 
     | 
    
         
            -
                    show=False,
         
     | 
| 
       159 
     | 
    
         
            -
                    legend=True,
         
     | 
| 
       160 
     | 
    
         
            -
                    title=None,
         
     | 
| 
       161 
     | 
    
         
            -
                    verbose=False,
         
     | 
| 
       162 
     | 
    
         
            -
                    x_coeff=0.30,
         
     | 
| 
       163 
     | 
    
         
            -
                    y_coeff=0.8,
         
     | 
| 
       164 
     | 
    
         
            -
                    fontsize=14,
         
     | 
| 
       165 
     | 
    
         
            -
                    **kwargs,
         
     | 
| 
       166 
     | 
    
         
            -
                ):
         
     | 
| 
       167 
     | 
    
         
            -
                    """Plots variance matching figure for a set of specific genes.
         
     | 
| 
       168 
     | 
    
         
            -
             
     | 
| 
       169 
     | 
    
         
            -
                    Args:
         
     | 
| 
       170 
     | 
    
         
            -
                        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
         
     | 
| 
       171 
     | 
    
         
            -
                               AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
         
     | 
| 
       172 
     | 
    
         
            -
                               corresponding to batch and cell type metadata, respectively.
         
     | 
| 
       173 
     | 
    
         
            -
                        condition_key: Key of the condition.
         
     | 
| 
       174 
     | 
    
         
            -
                        axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
         
     | 
| 
       175 
     | 
    
         
            -
                                   `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
         
     | 
| 
       176 
     | 
    
         
            -
                        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
         
     | 
| 
       177 
     | 
    
         
            -
                        path_to_save: path to save the plot.
         
     | 
| 
       178 
     | 
    
         
            -
                        save: Specify if the plot should be saved or not.
         
     | 
| 
       179 
     | 
    
         
            -
                        gene_list: list of gene names to be plotted.
         
     | 
| 
       180 
     | 
    
         
            -
                        show: if `True`: will show to the plot after saving it.
         
     | 
| 
       181 
     | 
    
         
            -
                        top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
         
     | 
| 
       182 
     | 
    
         
            -
                        legend: if `True`: plots a legend, defaults to `True`.
         
     | 
| 
       183 
     | 
    
         
            -
                        title: Set if you want the plot to display a title.
         
     | 
| 
       184 
     | 
    
         
            -
                        verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
         
     | 
| 
       185 
     | 
    
         
            -
                        x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
         
     | 
| 
       186 
     | 
    
         
            -
                        y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
         
     | 
| 
       187 
     | 
    
         
            -
                        fontsize: Fontsize used for text in the plot, defaults to 14.
         
     | 
| 
       188 
     | 
    
         
            -
                    """
         
     | 
| 
       189 
     | 
    
         
            -
                    import seaborn as sns
         
     | 
| 
       190 
     | 
    
         
            -
             
     | 
| 
       191 
     | 
    
         
            -
                    sns.set()
         
     | 
| 
       192 
     | 
    
         
            -
                    sns.set(color_codes=True)
         
     | 
| 
       193 
     | 
    
         
            -
             
     | 
| 
       194 
     | 
    
         
            -
                    sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")
         
     | 
| 
       195 
     | 
    
         
            -
                    diff_genes = top_100_genes
         
     | 
| 
       196 
     | 
    
         
            -
                    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
         
     | 
| 
       197 
     | 
    
         
            -
                    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
         
     | 
| 
       198 
     | 
    
         
            -
                    if diff_genes is not None:
         
     | 
| 
       199 
     | 
    
         
            -
                        if hasattr(diff_genes, "tolist"):
         
     | 
| 
       200 
     | 
    
         
            -
                            diff_genes = diff_genes.tolist()
         
     | 
| 
       201 
     | 
    
         
            -
                        adata_diff = adata[:, diff_genes]
         
     | 
| 
       202 
     | 
    
         
            -
                        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
         
     | 
| 
       203 
     | 
    
         
            -
                        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
         
     | 
| 
       204 
     | 
    
         
            -
                        x_diff = np.asarray(np.var(ctrl_diff.X, axis=0)).ravel()
         
     | 
| 
       205 
     | 
    
         
            -
                        y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
         
     | 
| 
       206 
     | 
    
         
            -
                        m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
         
     | 
| 
       207 
     | 
    
         
            -
                        if verbose:
         
     | 
| 
       208 
     | 
    
         
            -
                            print("Top 100 DEGs var: ", r_value_diff**2)
         
     | 
| 
       209 
     | 
    
         
            -
                    if "y1" in axis_keys.keys():
         
     | 
| 
       210 
     | 
    
         
            -
                        real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
         
     | 
| 
       211 
     | 
    
         
            -
                    x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
         
     | 
| 
       212 
     | 
    
         
            -
                    y = np.asarray(np.var(stim.X, axis=0)).ravel()
         
     | 
| 
       213 
     | 
    
         
            -
                    m, b, r_value, p_value, std_err = stats.linregress(x, y)
         
     | 
| 
       214 
     | 
    
         
            -
                    if verbose:
         
     | 
| 
       215 
     | 
    
         
            -
                        print("All genes var: ", r_value**2)
         
     | 
| 
       216 
     | 
    
         
            -
                    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
         
     | 
| 
       217 
     | 
    
         
            -
                    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
         
     | 
| 
       218 
     | 
    
         
            -
                    ax.tick_params(labelsize=fontsize)
         
     | 
| 
       219 
     | 
    
         
            -
                    if "range" in kwargs:
         
     | 
| 
       220 
     | 
    
         
            -
                        start, stop, step = kwargs.get("range")
         
     | 
| 
       221 
     | 
    
         
            -
                        ax.set_xticks(np.arange(start, stop, step))
         
     | 
| 
       222 
     | 
    
         
            -
                        ax.set_yticks(np.arange(start, stop, step))
         
     | 
| 
       223 
     | 
    
         
            -
                    # _p1 = pyplot.scatter(x, y, marker=".", label=f"{axis_keys['x']}-{axis_keys['y']}")
         
     | 
| 
       224 
     | 
    
         
            -
                    # pyplot.plot(x, m * x + b, "-", color="green")
         
     | 
| 
       225 
     | 
    
         
            -
                    ax.set_xlabel(labels["x"], fontsize=fontsize)
         
     | 
| 
       226 
     | 
    
         
            -
                    ax.set_ylabel(labels["y"], fontsize=fontsize)
         
     | 
| 
       227 
     | 
    
         
            -
                    if "y1" in axis_keys.keys():
         
     | 
| 
       228 
     | 
    
         
            -
                        y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
         
     | 
| 
       229 
     | 
    
         
            -
                        _ = pyplot.scatter(
         
     | 
| 
       230 
     | 
    
         
            -
                            x,
         
     | 
| 
       231 
     | 
    
         
            -
                            y1,
         
     | 
| 
       232 
     | 
    
         
            -
                            marker="*",
         
     | 
| 
       233 
     | 
    
         
            -
                            c="grey",
         
     | 
| 
       234 
     | 
    
         
            -
                            alpha=0.5,
         
     | 
| 
       235 
     | 
    
         
            -
                            label=f"{axis_keys['x']}-{axis_keys['y1']}",
         
     | 
| 
       236 
     | 
    
         
            -
                        )
         
     | 
| 
       237 
     | 
    
         
            -
                    if gene_list is not None:
         
     | 
| 
       238 
     | 
    
         
            -
                        for i in gene_list:
         
     | 
| 
       239 
     | 
    
         
            -
                            j = adata.var_names.tolist().index(i)
         
     | 
| 
       240 
     | 
    
         
            -
                            x_bar = x[j]
         
     | 
| 
       241 
     | 
    
         
            -
                            y_bar = y[j]
         
     | 
| 
       242 
     | 
    
         
            -
                            pyplot.text(x_bar, y_bar, i, fontsize=11, color="black")
         
     | 
| 
       243 
     | 
    
         
            -
                            pyplot.plot(x_bar, y_bar, "o", color="red", markersize=5)
         
     | 
| 
       244 
     | 
    
         
            -
                            if "y1" in axis_keys.keys():
         
     | 
| 
       245 
     | 
    
         
            -
                                y1_bar = y1[j]
         
     | 
| 
       246 
     | 
    
         
            -
                                pyplot.text(x_bar, y1_bar, "*", color="black", alpha=0.5)
         
     | 
| 
       247 
     | 
    
         
            -
                    if legend:
         
     | 
| 
       248 
     | 
    
         
            -
                        pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5))
         
     | 
| 
       249 
     | 
    
         
            -
                    if title is None:
         
     | 
| 
       250 
     | 
    
         
            -
                        pyplot.title("", fontsize=12)
         
     | 
| 
       251 
     | 
    
         
            -
                    else:
         
     | 
| 
       252 
     | 
    
         
            -
                        pyplot.title(title, fontsize=12)
         
     | 
| 
       253 
     | 
    
         
            -
                    ax.text(
         
     | 
| 
       254 
     | 
    
         
            -
                        max(x) - max(x) * x_coeff,
         
     | 
| 
       255 
     | 
    
         
            -
                        max(y) - y_coeff * max(y),
         
     | 
| 
       256 
     | 
    
         
            -
                        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
         
     | 
| 
       257 
     | 
    
         
            -
                        fontsize=kwargs.get("textsize", fontsize),
         
     | 
| 
       258 
     | 
    
         
            -
                    )
         
     | 
| 
       259 
     | 
    
         
            -
                    if diff_genes is not None:
         
     | 
| 
       260 
     | 
    
         
            -
                        ax.text(
         
     | 
| 
       261 
     | 
    
         
            -
                            max(x) - max(x) * x_coeff,
         
     | 
| 
       262 
     | 
    
         
            -
                            max(y) - (y_coeff + 0.15) * max(y),
         
     | 
| 
       263 
     | 
    
         
            -
                            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
         
     | 
| 
       264 
     | 
    
         
            -
                            fontsize=kwargs.get("textsize", fontsize),
         
     | 
| 
       265 
     | 
    
         
            -
                        )
         
     | 
| 
       266 
     | 
    
         
            -
             
     | 
| 
       267 
     | 
    
         
            -
                    if save:
         
     | 
| 
       268 
     | 
    
         
            -
                        pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
         
     | 
| 
       269 
     | 
    
         
            -
                    if show:
         
     | 
| 
       270 
     | 
    
         
            -
                        pyplot.show()
         
     | 
| 
       271 
     | 
    
         
            -
                    pyplot.close()
         
     | 
| 
       272 
     | 
    
         
            -
                    if diff_genes is not None:
         
     | 
| 
       273 
     | 
    
         
            -
                        return r_value**2, r_value_diff**2
         
     | 
| 
       274 
     | 
    
         
            -
                    else:
         
     | 
| 
       275 
     | 
    
         
            -
                        return r_value**2
         
     | 
| 
       276 
     | 
    
         
            -
             
     | 
| 
       277 
     | 
    
         
            -
                @staticmethod
                def binary_classifier(
                    scgen,
                    adata,
                    delta,
                    ctrl_key,
                    stim_key,
                    path_to_save,
                    save=True,
                    fontsize=14,
                ):
                    """Plot how control and stimulated cells project onto ``delta`` in latent space.

                    Every cell's latent vector is reduced to a single score, the dot product
                    with ``delta``. The scores of the control and the stimulated cells are
                    drawn as two histograms (50 bins each) in the current figure, together
                    with a dashed vertical line at zero. All previously open figures are
                    closed before plotting, and the figure is always shown at the end.

                    Args:
                        scgen: ScGen object that was trained.
                        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                               AnnData object used to initialize the model. Must have been setup with `batch_key` and
                               `labels_key`, corresponding to batch and cell type metadata, respectively.
                        delta: Difference between stimulated and control cells in latent space.
                        ctrl_key: Value of the condition column that marks control cells.
                        stim_key: Value of the condition column that marks stimulated cells.
                        path_to_save: Path the figure is written to when `save` is truthy.
                        save: Whether to save the figure to `path_to_save`.
                        fontsize: Font size used for title, axis labels and tick labels.
                    """

                    def _delta_scores(latent):
                        # One score per cell: dot product of delta with the cell's latent vector.
                        scores = np.zeros(len(latent))
                        for row, latent_vec in enumerate(latent):
                            scores[row] = np.dot(delta, latent_vec)
                        return scores

                    pyplot.close("all")
                    adata = scgen._validate_anndata(adata)

                    # The condition column is the one registered as batch key on the model.
                    batch_registry = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY)
                    condition_key = batch_registry.original_key

                    ctrl_cells = adata[adata.obs[condition_key] == ctrl_key, :]
                    stim_cells = adata[adata.obs[condition_key] == stim_key, :]

                    ctrl_latent = scgen.get_latent_representation(ctrl_cells.X)
                    stim_latent = scgen.get_latent_representation(stim_cells.X)

                    ctrl_scores = _delta_scores(ctrl_latent)
                    stim_scores = _delta_scores(stim_latent)

                    # Control first, then stimulated, so the draw order stays the same.
                    for scores, group_label in ((ctrl_scores, ctrl_key), (stim_scores, stim_key)):
                        pyplot.hist(scores, label=group_label, bins=50)
                    pyplot.axvline(0, color="k", linestyle="dashed", linewidth=1)

                    # Title and axis labels are intentionally blank placeholders.
                    for set_text in (pyplot.title, pyplot.xlabel, pyplot.ylabel):
                        set_text("  ", fontsize=fontsize)
                    pyplot.xticks(fontsize=fontsize)
                    pyplot.yticks(fontsize=fontsize)
                    pyplot.gca().grid(False)

                    if save:
                        pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
                    pyplot.show()
         
     | 
| 
         @@ -1,99 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from __future__ import annotations
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            from typing import TYPE_CHECKING, Literal
         
     | 
| 
       4 
     | 
    
         
            -
             
     | 
| 
       5 
     | 
    
         
            -
            import decoupler as dc
         
     | 
| 
       6 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       7 
     | 
    
         
            -
            import numpy.typing as npt
         
     | 
| 
       8 
     | 
    
         
            -
             
     | 
| 
       9 
     | 
    
         
            -
            if TYPE_CHECKING:
         
     | 
| 
       10 
     | 
    
         
            -
                import pandas as pd
         
     | 
| 
       11 
     | 
    
         
            -
                from anndata import AnnData
         
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
     | 
    
         
            -
             
     | 
| 
       14 
     | 
    
         
            -
            class DifferentialGeneExpression:
                """Support for differential gene expression for scverse."""

                def pseudobulk(
                    self,
                    adata: AnnData,
                    sample_col: str,
                    groups_col: str,
                    obs: pd.DataFrame | None = None,
                    layer: str | None = None,
                    use_raw: bool = False,
                    min_prop: float = 0.2,
                    min_counts: int = 1000,
                    min_samples: int = 2,
                    dtype: npt.DTypeLike = np.float32,
                ) -> AnnData:
                    """Generate Pseudobulk for DE analysis.

                    Wraps decoupler's get_pseudobulk function.
                    See: https://decoupler-py.readthedocs.io/en/latest/generated/decoupler.get_pseudobulk.html#decoupler.get_pseudobulk
                    for more details.

                    Args:
                        adata: Input AnnData object.
                        sample_col: Column of obs where to extract the samples names.
                        groups_col: Column of obs where to extract the groups names.
                        obs: If provided, metadata dataframe.
                        layer: If provided, which layer to use.
                        use_raw: Use raw attribute of adata if present.
                        min_prop: Minimum proportion of cells with non-zero values.
                        min_counts: Minimum number of counts per sample
                                    (forwarded unchanged as decoupler's `min_counts`).
                        min_samples: Minimum number of samples per feature
                                     (forwarded as decoupler's `min_smpls`).
                        dtype: Type of float used.

                    Returns:
                        Returns new AnnData object with unnormalized pseudobulk profiles per sample and group.
                    """
                    pseudobulk_adata = dc.get_pseudobulk(
                        adata,
                        sample_col=sample_col,
                        groups_col=groups_col,
                        obs=obs,
                        layer=layer,
                        use_raw=use_raw,
                        min_prop=min_prop,
                        min_counts=min_counts,
                        # decoupler names this parameter `min_smpls`.
                        min_smpls=min_samples,
                        dtype=dtype,
                    )

                    return pseudobulk_adata

                def de_analysis(
                    self,
                    adata: AnnData,
                    groupby: str,
                    method: Literal["t-test", "wilcoxon", "pydeseq2", "deseq2", "edger"],
                    # NOTE(review): `*formula` collects extra positional arguments into a tuple and
                    # makes everything after it keyword-only; kept as-is for interface compatibility.
                    *formula: str | None,
                    contrast: str | None = None,
                    inplace: bool = True,
                    key_added: str | None = None,
                ) -> pd.DataFrame:
                    """Perform differential expression analysis.

                    This method is a placeholder: it performs no computation and always raises
                    `NotImplementedError`.

                    Args:
                        adata: single-cell or pseudobulk AnnData object
                        groupby: Column in adata.obs that contains the factor to test, e.g. `treatment`.
                                 For simple statistical tests (t-test, wilcoxon), it is sufficient to specify groupby.
                                 Linear models require to specify a formula.
                                 In that case, the `groupby` column is used to compute the contrast.
                        method: Which method to use to perform the DE test.
                        formula: model specification for linear models. E.g. `~ treatment + sex + age`.
                                 MUST contain the factor specified in `groupby`.
                        contrast: See e.g. https://www.statsmodels.org/devel/contrasts.html for more information.
                        inplace: if True, save the result in `adata.varm[key_added]`
                        key_added: Key under which the result is saved in `adata.varm` if inplace is True.
                                   If set to None this defaults to `de_{method}_{groupby}`.

                    Returns:
                        Depending on the method a Pandas DataFrame containing at least:
                        * gene_id
                        * log2 fold change
                        * mean expression
                        * unadjusted p-value
                        * adjusted p-value

                    Raises:
                        NotImplementedError: Always; the analysis is not implemented.
                    """
                    raise NotImplementedError
         
     | 
| 
         
            File without changes
         
     |