pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
 - pertpy/data/__init__.py +66 -1
 - pertpy/data/_dataloader.py +28 -26
 - pertpy/data/_datasets.py +261 -92
 - pertpy/metadata/__init__.py +6 -0
 - pertpy/metadata/_cell_line.py +795 -0
 - pertpy/metadata/_compound.py +128 -0
 - pertpy/metadata/_drug.py +238 -0
 - pertpy/metadata/_look_up.py +569 -0
 - pertpy/metadata/_metadata.py +70 -0
 - pertpy/metadata/_moa.py +125 -0
 - pertpy/plot/__init__.py +0 -13
 - pertpy/preprocessing/__init__.py +2 -0
 - pertpy/preprocessing/_guide_rna.py +89 -6
 - pertpy/tools/__init__.py +48 -15
 - pertpy/tools/_augur.py +329 -32
 - pertpy/tools/_cinemaot.py +145 -6
 - pertpy/tools/_coda/_base_coda.py +1237 -116
 - pertpy/tools/_coda/_sccoda.py +66 -36
 - pertpy/tools/_coda/_tasccoda.py +46 -39
 - pertpy/tools/_dialogue.py +180 -77
 - pertpy/tools/_differential_gene_expression/__init__.py +20 -0
 - pertpy/tools/_differential_gene_expression/_base.py +657 -0
 - pertpy/tools/_differential_gene_expression/_checks.py +41 -0
 - pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
 - pertpy/tools/_differential_gene_expression/_edger.py +125 -0
 - pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
 - pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
 - pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
 - pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
 - pertpy/tools/_distances/_distance_tests.py +29 -24
 - pertpy/tools/_distances/_distances.py +584 -98
 - pertpy/tools/_enrichment.py +460 -0
 - pertpy/tools/_kernel_pca.py +1 -1
 - pertpy/tools/_milo.py +406 -49
 - pertpy/tools/_mixscape.py +677 -55
 - pertpy/tools/_perturbation_space/_clustering.py +10 -3
 - pertpy/tools/_perturbation_space/_comparison.py +112 -0
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
 - pertpy/tools/_perturbation_space/_simple.py +52 -11
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_base_components.py +2 -3
 - pertpy/tools/_scgen/_scgen.py +706 -0
 - pertpy/tools/_scgen/_utils.py +3 -5
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
 - pertpy-0.8.0.dist-info/RECORD +57 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_augur.py +0 -234
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_coda.py +0 -1001
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_guide_rna.py +0 -82
 - pertpy/plot/_milopy.py +0 -284
 - pertpy/plot/_mixscape.py +0 -594
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_differential_gene_expression.py +0 -99
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
 
    
        pertpy/tools/_dialogue.py
    CHANGED
    
    | 
         @@ -2,27 +2,33 @@ from __future__ import annotations 
     | 
|
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            import itertools
         
     | 
| 
       4 
4 
     | 
    
         
             
            from collections import defaultdict
         
     | 
| 
       5 
     | 
    
         
            -
            from typing import Any, Literal
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import TYPE_CHECKING, Any, Literal
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         
             
            import anndata as ad
         
     | 
| 
      
 8 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
       8 
9 
     | 
    
         
             
            import numpy as np
         
     | 
| 
       9 
10 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
       10 
11 
     | 
    
         
             
            import scanpy as sc
         
     | 
| 
       11 
     | 
    
         
            -
            import  
     | 
| 
      
 12 
     | 
    
         
            +
            import seaborn as sns
         
     | 
| 
       12 
13 
     | 
    
         
             
            import statsmodels.formula.api as smf
         
     | 
| 
       13 
14 
     | 
    
         
             
            import statsmodels.stats.multitest as ssm
         
     | 
| 
       14 
15 
     | 
    
         
             
            from anndata import AnnData
         
     | 
| 
      
 16 
     | 
    
         
            +
            from lamin_utils import logger
         
     | 
| 
       15 
17 
     | 
    
         
             
            from pandas import DataFrame
         
     | 
| 
       16 
     | 
    
         
            -
            from rich import print
         
     | 
| 
       17 
18 
     | 
    
         
             
            from rich.console import Group
         
     | 
| 
       18 
19 
     | 
    
         
             
            from rich.live import Live
         
     | 
| 
       19 
20 
     | 
    
         
             
            from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
         
     | 
| 
       20 
21 
     | 
    
         
             
            from scipy import stats
         
     | 
| 
       21 
22 
     | 
    
         
             
            from scipy.optimize import nnls
         
     | 
| 
      
 23 
     | 
    
         
            +
            from seaborn import PairGrid
         
     | 
| 
       22 
24 
     | 
    
         
             
            from sklearn.linear_model import LinearRegression
         
     | 
| 
       23 
25 
     | 
    
         
             
            from sparsecca import lp_pmd, multicca_permute, multicca_pmd
         
     | 
| 
       24 
26 
     | 
    
         
             
            from statsmodels.sandbox.stats.multicomp import multipletests
         
     | 
| 
       25 
27 
     | 
    
         | 
| 
      
 28 
     | 
    
         
            +
            if TYPE_CHECKING:
         
     | 
| 
      
 29 
     | 
    
         
            +
                from matplotlib.axes import Axes
         
     | 
| 
      
 30 
     | 
    
         
            +
                from matplotlib.figure import Figure
         
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
       26 
32 
     | 
    
         | 
| 
       27 
33 
     | 
    
         
             
            class Dialogue:
         
     | 
| 
       28 
34 
     | 
    
         
             
                """Python implementation of DIALOGUE"""
         
     | 
| 
         @@ -53,8 +59,6 @@ class Dialogue: 
     | 
|
| 
       53 
59 
     | 
    
         | 
| 
       54 
60 
     | 
    
         
             
                    Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
         
     | 
| 
       55 
61 
     | 
    
         | 
| 
       56 
     | 
    
         
            -
                    # TODO: Replace with decoupler's implementation
         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
62 
     | 
    
         
             
                    Args:
         
     | 
| 
       59 
63 
     | 
    
         
             
                        groupby: The key to groupby for pseudobulks
         
     | 
| 
       60 
64 
     | 
    
         
             
                        strategy: The pseudobulking strategy. One of "median" or "mean"
         
     | 
| 
         @@ -62,14 +66,15 @@ class Dialogue: 
     | 
|
| 
       62 
66 
     | 
    
         
             
                    Returns:
         
     | 
| 
       63 
67 
     | 
    
         
             
                        A Pandas DataFrame of pseudobulk counts
         
     | 
| 
       64 
68 
     | 
    
         
             
                    """
         
     | 
| 
      
 69 
     | 
    
         
            +
                    # TODO: Replace with decoupler's implementation
         
     | 
| 
       65 
70 
     | 
    
         
             
                    pseudobulk = {"Genes": adata.var_names.values}
         
     | 
| 
       66 
71 
     | 
    
         | 
| 
       67 
72 
     | 
    
         
             
                    for category in adata.obs.loc[:, groupby].cat.categories:
         
     | 
| 
       68 
73 
     | 
    
         
             
                        temp = adata.obs.loc[:, groupby] == category
         
     | 
| 
       69 
74 
     | 
    
         
             
                        if strategy == "median":
         
     | 
| 
       70 
     | 
    
         
            -
                            pseudobulk[category] = adata[temp].X.median(axis=0) 
     | 
| 
      
 75 
     | 
    
         
            +
                            pseudobulk[category] = adata[temp].X.median(axis=0)
         
     | 
| 
       71 
76 
     | 
    
         
             
                        elif strategy == "mean":
         
     | 
| 
       72 
     | 
    
         
            -
                            pseudobulk[category] = adata[temp].X.mean(axis=0) 
     | 
| 
      
 77 
     | 
    
         
            +
                            pseudobulk[category] = adata[temp].X.mean(axis=0)
         
     | 
| 
       73 
78 
     | 
    
         | 
| 
       74 
79 
     | 
    
         
             
                    pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
         
     | 
| 
       75 
80 
     | 
    
         | 
| 
         @@ -101,8 +106,6 @@ class Dialogue: 
     | 
|
| 
       101 
106 
     | 
    
         
             
                def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
         
     | 
| 
       102 
107 
     | 
    
         
             
                    """Row-wise mean center and scale by the standard deviation.
         
     | 
| 
       103 
108 
     | 
    
         | 
| 
       104 
     | 
    
         
            -
                    TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
         
     | 
| 
       105 
     | 
    
         
            -
             
     | 
| 
       106 
109 
     | 
    
         
             
                    Args:
         
     | 
| 
       107 
110 
     | 
    
         
             
                        pseudobulks: The pseudobulk PCA components.
         
     | 
| 
       108 
111 
     | 
    
         
             
                        normalize: Whether to mimic DIALOGUE behavior or not.
         
     | 
| 
         @@ -110,9 +113,9 @@ class Dialogue: 
     | 
|
| 
       110 
113 
     | 
    
         
             
                    Returns:
         
     | 
| 
       111 
114 
     | 
    
         
             
                        The scaled count matrix.
         
     | 
| 
       112 
115 
     | 
    
         
             
                    """
         
     | 
| 
      
 116 
     | 
    
         
            +
                    # TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
         
     | 
| 
       113 
117 
     | 
    
         
             
                    # DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
         
     | 
| 
       114 
118 
     | 
    
         
             
                    # However, performing this scaling _does_ increase overall correlation of the end result
         
     | 
| 
       115 
     | 
    
         
            -
                    # WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
         
     | 
| 
       116 
119 
     | 
    
         
             
                    if normalize:
         
     | 
| 
       117 
120 
     | 
    
         
             
                        return pseudobulks.to_numpy()
         
     | 
| 
       118 
121 
     | 
    
         
             
                    else:
         
     | 
| 
         @@ -288,7 +291,7 @@ class Dialogue: 
     | 
|
| 
       288 
291 
     | 
    
         
             
                        mcp_name: Name of mcp which was used for calculation of column value.
         
     | 
| 
       289 
292 
     | 
    
         
             
                        max_length: Value needed to later decide at what index the threshold value should be extracted from column.
         
     | 
| 
       290 
293 
     | 
    
         
             
                        min_threshold: Minimal threshold to select final scores by if it is smaller than calculated threshold.
         
     | 
| 
       291 
     | 
    
         
            -
                        index: Column index to use eto calculate the significant genes. 
     | 
| 
      
 294 
     | 
    
         
            +
                        index: Column index to use to calculate the significant genes.
         
     | 
| 
       292 
295 
     | 
    
         | 
| 
       293 
296 
     | 
    
         
             
                    Returns:
         
     | 
| 
       294 
297 
     | 
    
         
             
                        According to the values in a df column (default: zscore) the significant up and downregulated gene names
         
     | 
| 
         @@ -313,13 +316,13 @@ class Dialogue: 
     | 
|
| 
       313 
316 
     | 
    
         
             
                def _apply_HLM_per_MCP_for_one_pair(
         
     | 
| 
       314 
317 
     | 
    
         
             
                    self,
         
     | 
| 
       315 
318 
     | 
    
         
             
                    mcp_name: str,
         
     | 
| 
       316 
     | 
    
         
            -
                    scores_df:  
     | 
| 
      
 319 
     | 
    
         
            +
                    scores_df: pd.DataFrame,
         
     | 
| 
       317 
320 
     | 
    
         
             
                    ct_data: AnnData,
         
     | 
| 
       318 
321 
     | 
    
         
             
                    tme: pd.DataFrame,
         
     | 
| 
       319 
322 
     | 
    
         
             
                    sig: dict,
         
     | 
| 
       320 
323 
     | 
    
         
             
                    n_counts: str,
         
     | 
| 
       321 
324 
     | 
    
         
             
                    formula: str,
         
     | 
| 
       322 
     | 
    
         
            -
                    confounder: str,
         
     | 
| 
      
 325 
     | 
    
         
            +
                    confounder: str | None,
         
     | 
| 
       323 
326 
     | 
    
         
             
                ) -> tuple[pd.DataFrame, dict[str, Any]]:
         
     | 
| 
       324 
327 
     | 
    
         
             
                    """Applies hierarchical modeling for a single MCP.
         
     | 
| 
       325 
328 
     | 
    
         | 
| 
         @@ -340,7 +343,7 @@ class Dialogue: 
     | 
|
| 
       340 
343 
     | 
    
         
             
                    """
         
     | 
| 
       341 
344 
     | 
    
         
             
                    HLM_result = self._mixed_effects(
         
     | 
| 
       342 
345 
     | 
    
         
             
                        scores=scores_df[[mcp_name]],
         
     | 
| 
       343 
     | 
    
         
            -
                        x_labels=ct_data.obs[[n_counts, confounder]],
         
     | 
| 
      
 346 
     | 
    
         
            +
                        x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
         
     | 
| 
       344 
347 
     | 
    
         
             
                        tme=tme,
         
     | 
| 
       345 
348 
     | 
    
         
             
                        genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
         
     | 
| 
       346 
349 
     | 
    
         
             
                        formula=formula,
         
     | 
| 
         @@ -367,19 +370,13 @@ class Dialogue: 
     | 
|
| 
       367 
370 
     | 
    
         
             
                    return np.array(resid)
         
     | 
| 
       368 
371 
     | 
    
         | 
| 
       369 
372 
     | 
    
         
             
                def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
         
     | 
| 
       370 
     | 
    
         
            -
                    """Solves non-negative least 
     | 
| 
      
 373 
     | 
    
         
            +
                    """Solves non-negative least-squares separately for different feature categories.
         
     | 
| 
       371 
374 
     | 
    
         | 
| 
       372 
375 
     | 
    
         
             
                    Mimics DLG.iterative.nnls.
         
     | 
| 
       373 
376 
     | 
    
         
             
                    Variables are notated according to:
         
     | 
| 
       374 
377 
     | 
    
         | 
| 
       375 
378 
     | 
    
         
             
                        `argmin|Ax - y|`
         
     | 
| 
       376 
379 
     | 
    
         | 
| 
       377 
     | 
    
         
            -
                    Args:
         
     | 
| 
       378 
     | 
    
         
            -
                        A_orig:
         
     | 
| 
       379 
     | 
    
         
            -
                        y_orig:
         
     | 
| 
       380 
     | 
    
         
            -
                        feature_ranks:
         
     | 
| 
       381 
     | 
    
         
            -
                        n_iter: Passed to scipy.optimize.nnls. Defaults to 1000.
         
     | 
| 
       382 
     | 
    
         
            -
             
     | 
| 
       383 
380 
     | 
    
         
             
                    Returns:
         
     | 
| 
       384 
381 
     | 
    
         
             
                        Returns the aggregated coefficients from nnls.
         
     | 
| 
       385 
382 
     | 
    
         
             
                    """
         
     | 
| 
         @@ -398,7 +395,7 @@ class Dialogue: 
     | 
|
| 
       398 
395 
     | 
    
         | 
| 
       399 
396 
     | 
    
         
             
                    x_final = np.zeros(A_orig.shape[0])
         
     | 
| 
       400 
397 
     | 
    
         
             
                    Ax = np.zeros(A_orig.shape[1])
         
     | 
| 
       401 
     | 
    
         
            -
                    for _, mask in zip(sig_ranks, masks):
         
     | 
| 
      
 398 
     | 
    
         
            +
                    for _, mask in zip(sig_ranks, masks, strict=False):
         
     | 
| 
       402 
399 
     | 
    
         
             
                        A = A_orig[mask].T
         
     | 
| 
       403 
400 
     | 
    
         
             
                        coef_nnls, _ = nnls(A, y, maxiter=n_iter)
         
     | 
| 
       404 
401 
     | 
    
         
             
                        y = y - A @ coef_nnls  # residuals
         
     | 
| 
         @@ -516,8 +513,8 @@ class Dialogue: 
     | 
|
| 
       516 
513 
     | 
    
         
             
                        # TODO: probably format the up and down within get_top_elements
         
     | 
| 
       517 
514 
     | 
    
         
             
                        cca_sig: dict[str, Any] = defaultdict(dict)
         
     | 
| 
       518 
515 
     | 
    
         
             
                        for i in range(0, int(len(cca_sig_unformatted) / 2)):
         
     | 
| 
       519 
     | 
    
         
            -
                            cca_sig[f"MCP{i 
     | 
| 
       520 
     | 
    
         
            -
                            cca_sig[f"MCP{i 
     | 
| 
      
 516 
     | 
    
         
            +
                            cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
         
     | 
| 
      
 517 
     | 
    
         
            +
                            cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
         
     | 
| 
       521 
518 
     | 
    
         | 
| 
       522 
519 
     | 
    
         
             
                        cca_sig = dict(cca_sig)
         
     | 
| 
       523 
520 
     | 
    
         
             
                        cca_sig_results[ct] = cca_sig
         
     | 
| 
         @@ -555,7 +552,7 @@ class Dialogue: 
     | 
|
| 
       555 
552 
     | 
    
         | 
| 
       556 
553 
     | 
    
         
             
                    return cca_sig_results, new_mcp_scores
         
     | 
| 
       557 
554 
     | 
    
         | 
| 
       558 
     | 
    
         
            -
                def  
     | 
| 
      
 555 
     | 
    
         
            +
                def _load(
         
     | 
| 
       559 
556 
     | 
    
         
             
                    self,
         
     | 
| 
       560 
557 
     | 
    
         
             
                    adata: AnnData,
         
     | 
| 
       561 
558 
     | 
    
         
             
                    ct_order: list[str],
         
     | 
| 
         @@ -569,21 +566,11 @@ class Dialogue: 
     | 
|
| 
       569 
566 
     | 
    
         
             
                    Args:
         
     | 
| 
       570 
567 
     | 
    
         
             
                        adata: AnnData object to generate celltype objects for
         
     | 
| 
       571 
568 
     | 
    
         
             
                        ct_order: The order of cell types
         
     | 
| 
       572 
     | 
    
         
            -
                        agg_pca: Whether to aggregate pseudobulks with PCA or not. 
     | 
| 
       573 
     | 
    
         
            -
                        normalize: Whether to mimic DIALOGUE behavior or not. 
     | 
| 
      
 569 
     | 
    
         
            +
                        agg_pca: Whether to aggregate pseudobulks with PCA or not.
         
     | 
| 
      
 570 
     | 
    
         
            +
                        normalize: Whether to mimic DIALOGUE behavior or not.
         
     | 
| 
       574 
571 
     | 
    
         | 
| 
       575 
572 
     | 
    
         
             
                    Returns:
         
     | 
| 
       576 
573 
     | 
    
         
             
                        A celltype_label:array dictionary.
         
     | 
| 
       577 
     | 
    
         
            -
             
     | 
| 
       578 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       579 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       580 
     | 
    
         
            -
                        >>> import scanpy as sc
         
     | 
| 
       581 
     | 
    
         
            -
                        >>> adata = pt.dt.dialogue_example()
         
     | 
| 
       582 
     | 
    
         
            -
                        >>> sc.pp.pca(adata)
         
     | 
| 
       583 
     | 
    
         
            -
                        >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
         
     | 
| 
       584 
     | 
    
         
            -
                            n_counts_key = "nCount_RNA", n_mpcs = 3)
         
     | 
| 
       585 
     | 
    
         
            -
                        >>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
         
     | 
| 
       586 
     | 
    
         
            -
                        >>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
         
     | 
| 
       587 
574 
     | 
    
         
             
                    """
         
     | 
| 
       588 
575 
     | 
    
         
             
                    ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
         
     | 
| 
       589 
576 
     | 
    
         
             
                    fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
         
     | 
| 
         @@ -620,7 +607,6 @@ class Dialogue: 
     | 
|
| 
       620 
607 
     | 
    
         
             
                        agg_pca: Whether to calculate cell-averaged PCA components.
         
     | 
| 
       621 
608 
     | 
    
         
             
                        solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
         
     | 
| 
       622 
609 
     | 
    
         
             
                                For differences between these two, please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
         
     | 
| 
       623 
     | 
    
         
            -
                                Defaults to 'bs'.
         
     | 
| 
       624 
610 
     | 
    
         
             
                        normalize: Whether to mimic DIALOGUE as close as possible
         
     | 
| 
       625 
611 
     | 
    
         | 
| 
       626 
612 
     | 
    
         
             
                    Returns:
         
     | 
| 
         @@ -631,25 +617,31 @@ class Dialogue: 
     | 
|
| 
       631 
617 
     | 
    
         
             
                        >>> import scanpy as sc
         
     | 
| 
       632 
618 
     | 
    
         
             
                        >>> adata = pt.dt.dialogue_example()
         
     | 
| 
       633 
619 
     | 
    
         
             
                        >>> sc.pp.pca(adata)
         
     | 
| 
       634 
     | 
    
         
            -
                        >>> dl = pt.tl.Dialogue( 
     | 
| 
       635 
     | 
    
         
            -
             
     | 
| 
      
 620 
     | 
    
         
            +
                        >>> dl = pt.tl.Dialogue(
         
     | 
| 
      
 621 
     | 
    
         
            +
                        ...     sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
         
     | 
| 
      
 622 
     | 
    
         
            +
                        ... )
         
     | 
| 
       636 
623 
     | 
    
         
             
                        >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
         
     | 
| 
       637 
624 
     | 
    
         
             
                    """
         
     | 
| 
       638 
     | 
    
         
            -
                    # IMPORTANT NOTE: the order in which matrices are passed to multicca matters. 
     | 
| 
       639 
     | 
    
         
            -
                    # it is important here that to obtain the same result as in R, we pass the matrices in
         
     | 
| 
       640 
     | 
    
         
            -
                    # in the same order.
         
     | 
| 
      
 625 
     | 
    
         
            +
                    # IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
         
     | 
| 
      
 626 
     | 
    
         
            +
                    # As such, to obtain the same result as in R, it is important that we pass the matrices in the same order.
         
     | 
| 
       641 
627 
     | 
    
         
             
                    if ct_order is not None:
         
     | 
| 
       642 
628 
     | 
    
         
             
                        cell_types = ct_order
         
     | 
| 
       643 
629 
     | 
    
         
             
                    else:
         
     | 
| 
       644 
630 
     | 
    
         
             
                        ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
         
     | 
| 
       645 
631 
     | 
    
         | 
| 
       646 
     | 
    
         
            -
                    mcca_in, ct_subs = self. 
     | 
| 
      
 632 
     | 
    
         
            +
                    mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
         
     | 
| 
       647 
633 
     | 
    
         | 
| 
       648 
634 
     | 
    
         
             
                    n_samples = mcca_in[0].shape[1]
         
     | 
| 
       649 
635 
     | 
    
         
             
                    if penalties is None:
         
     | 
| 
       650 
     | 
    
         
            -
                         
     | 
| 
       651 
     | 
    
         
            -
                             
     | 
| 
       652 
     | 
    
         
            -
             
     | 
| 
      
 636 
     | 
    
         
            +
                        try:
         
     | 
| 
      
 637 
     | 
    
         
            +
                            penalties = multicca_permute(
         
     | 
| 
      
 638 
     | 
    
         
            +
                                mcca_in, penalties=np.sqrt(n_samples) / 2, nperms=10, niter=50, standardize=True
         
     | 
| 
      
 639 
     | 
    
         
            +
                            )["bestpenalties"]
         
     | 
| 
      
 640 
     | 
    
         
            +
                        except ValueError as e:
         
     | 
| 
      
 641 
     | 
    
         
            +
                            if "matmul: input operand 1 has a mismatch in its core dimension" in str(e):
         
     | 
| 
      
 642 
     | 
    
         
            +
                                raise ValueError("Please ensure that every cell type is represented in every sample.") from e
         
     | 
| 
      
 643 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 644 
     | 
    
         
            +
                                raise
         
     | 
| 
       653 
645 
     | 
    
         
             
                    else:
         
     | 
| 
       654 
646 
     | 
    
         
             
                        penalties = penalties
         
     | 
| 
       655 
647 
     | 
    
         | 
| 
         @@ -685,7 +677,7 @@ class Dialogue: 
     | 
|
| 
       685 
677 
     | 
    
         
             
                    ct_subs: dict,
         
     | 
| 
       686 
678 
     | 
    
         
             
                    mcp_scores: dict,
         
     | 
| 
       687 
679 
     | 
    
         
             
                    ws_dict: dict,
         
     | 
| 
       688 
     | 
    
         
            -
                    confounder: str,
         
     | 
| 
      
 680 
     | 
    
         
            +
                    confounder: str | None,
         
     | 
| 
       689 
681 
     | 
    
         
             
                    formula: str = None,
         
     | 
| 
       690 
682 
     | 
    
         
             
                ):
         
     | 
| 
       691 
683 
     | 
    
         
             
                    """Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
         
     | 
| 
         @@ -700,7 +692,6 @@ class Dialogue: 
     | 
|
| 
       700 
692 
     | 
    
         
             
                        A Pandas DataFrame containing:
         
     | 
| 
       701 
693 
     | 
    
         
             
                        - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
         
     | 
| 
       702 
694 
     | 
    
         
             
                        - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
         
     | 
| 
       703 
     | 
    
         
            -
                        TODO: Describe both returns
         
     | 
| 
       704 
695 
     | 
    
         | 
| 
       705 
696 
     | 
    
         
             
                    Examples:
         
     | 
| 
       706 
697 
     | 
    
         
             
                        >>> import pertpy as pt
         
     | 
| 
         @@ -713,7 +704,9 @@ class Dialogue: 
     | 
|
| 
       713 
704 
     | 
    
         
             
                        >>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
         
     | 
| 
       714 
705 
     | 
    
         
             
                            confounder="gender")
         
     | 
| 
       715 
706 
     | 
    
         
             
                    """
         
     | 
| 
       716 
     | 
    
         
            -
                    #  
     | 
| 
      
 707 
     | 
    
         
            +
                    # TODO describe the returns of the function better
         
     | 
| 
      
 708 
     | 
    
         
            +
             
     | 
| 
      
 709 
     | 
    
         
            +
                    # all possible pairs of cell types without pairing same cell type
         
     | 
| 
       717 
710 
     | 
    
         
             
                    cell_types = list(ct_subs.keys())
         
     | 
| 
       718 
711 
     | 
    
         
             
                    pairs = list(itertools.combinations(cell_types, 2))
         
     | 
| 
       719 
712 
     | 
    
         | 
| 
         @@ -721,9 +714,9 @@ class Dialogue: 
     | 
|
| 
       721 
714 
     | 
    
         
             
                        formula = f"y ~ x + {self.n_counts_key}"
         
     | 
| 
       722 
715 
     | 
    
         | 
| 
       723 
716 
     | 
    
         
             
                    # Hierarchical modeling expects DataFrames
         
     | 
| 
       724 
     | 
    
         
            -
                    mcp_cell_types = {f"MCP{i 
     | 
| 
      
 717 
     | 
    
         
            +
                    mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
         
     | 
| 
       725 
718 
     | 
    
         
             
                    mcp_scores_df = {
         
     | 
| 
       726 
     | 
    
         
            -
                        ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
         
     | 
| 
      
 719 
     | 
    
         
            +
                        ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
         
     | 
| 
       727 
720 
     | 
    
         
             
                        for ct, v in mcp_scores.items()
         
     | 
| 
       728 
721 
     | 
    
         
             
                    }
         
     | 
| 
       729 
722 
     | 
    
         | 
| 
         @@ -762,10 +755,10 @@ class Dialogue: 
     | 
|
| 
       762 
755 
     | 
    
         
             
                                    mcps.append(mcp)
         
     | 
| 
       763 
756 
     | 
    
         | 
| 
       764 
757 
     | 
    
         
             
                            if len(mcps) == 0:
         
     | 
| 
       765 
     | 
    
         
            -
                                 
     | 
| 
      
 758 
     | 
    
         
            +
                                logger.warning(f"No shared MCPs between {cell_type_1} and {cell_type_2}.")
         
     | 
| 
       766 
759 
     | 
    
         
             
                                continue
         
     | 
| 
       767 
760 
     | 
    
         | 
| 
       768 
     | 
    
         
            -
                             
     | 
| 
      
 761 
     | 
    
         
            +
                            logger.info(f"{len(mcps)} MCPs identified for {cell_type_1} and {cell_type_2}.")
         
     | 
| 
       769 
762 
     | 
    
         | 
| 
       770 
763 
     | 
    
         
             
                            new_mcp_scores: dict[Any, list[Any]]
         
     | 
| 
       771 
764 
     | 
    
         
             
                            cca_sig, new_mcp_scores = self._calculate_cca_sig(
         
     | 
| 
         @@ -805,7 +798,7 @@ class Dialogue: 
     | 
|
| 
       805 
798 
     | 
    
         
             
                            for mcp in mcps:
         
     | 
| 
       806 
799 
     | 
    
         
             
                                mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
         
     | 
| 
       807 
800 
     | 
    
         | 
| 
       808 
     | 
    
         
            -
                                # TODO Check  
     | 
| 
      
 801 
     | 
    
         
            +
                                # TODO Check whether the genes in result["sig_genes_1"] are different and if so note that somewhere and explain why
         
     | 
| 
       809 
802 
     | 
    
         
             
                                result = {}
         
     | 
| 
       810 
803 
     | 
    
         
             
                                result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
         
     | 
| 
       811 
804 
     | 
    
         
             
                                    mcp_name=mcp,
         
     | 
| 
         @@ -875,22 +868,19 @@ class Dialogue: 
     | 
|
| 
       875 
868 
     | 
    
         
             
                    sample_label = self.sample_id
         
     | 
| 
       876 
869 
     | 
    
         
             
                    n_mcps = self.n_mcps
         
     | 
| 
       877 
870 
     | 
    
         | 
| 
       878 
     | 
    
         
            -
                    # create conditions_compare if not supplied
         
     | 
| 
       879 
871 
     | 
    
         
             
                    if conditions_compare is None:
         
     | 
| 
       880 
     | 
    
         
            -
                        conditions_compare = list(adata.obs[ 
     | 
| 
      
 872 
     | 
    
         
            +
                        conditions_compare = list(adata.obs[condition_label].cat.categories)  # type: ignore
         
     | 
| 
       881 
873 
     | 
    
         
             
                        if len(conditions_compare) != 2:
         
     | 
| 
       882 
874 
     | 
    
         
             
                            raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
         
     | 
| 
       883 
875 
     | 
    
         | 
| 
       884 
     | 
    
         
            -
                    # create data frames to store results
         
     | 
| 
       885 
876 
     | 
    
         
             
                    pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
         
     | 
| 
       886 
877 
     | 
    
         
             
                    tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
         
     | 
| 
       887 
878 
     | 
    
         
             
                    pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
         
     | 
| 
       888 
879 
     | 
    
         | 
| 
       889 
880 
     | 
    
         
             
                    response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
         
     | 
| 
       890 
881 
     | 
    
         
             
                    for celltype in adata.obs[celltype_label].unique():
         
     | 
| 
       891 
     | 
    
         
            -
                        # subset data to cell type
         
     | 
| 
       892 
882 
     | 
    
         
             
                        df = adata.obs[adata.obs[celltype_label] == celltype]
         
     | 
| 
       893 
     | 
    
         
            -
             
     | 
| 
      
 883 
     | 
    
         
            +
             
     | 
| 
       894 
884 
     | 
    
         
             
                        for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
         
     | 
| 
       895 
885 
     | 
    
         
             
                            mns = df.groupby(sample_label)[mcpnum].mean()
         
     | 
| 
       896 
886 
     | 
    
         
             
                            mns = pd.concat([mns, response], axis=1)
         
     | 
| 
         @@ -900,11 +890,10 @@ class Dialogue: 
     | 
|
| 
       900 
890 
     | 
    
         
             
                            )
         
     | 
| 
       901 
891 
     | 
    
         
             
                            pvals.loc[celltype, mcpnum] = res[1]
         
     | 
| 
       902 
892 
     | 
    
         
             
                            tstats.loc[celltype, mcpnum] = res[0]
         
     | 
| 
       903 
     | 
    
         
            -
                            # return(res)
         
     | 
| 
       904 
893 
     | 
    
         | 
| 
       905 
     | 
    
         
            -
                    # benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
         
     | 
| 
       906 
894 
     | 
    
         
             
                    for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
         
     | 
| 
       907 
895 
     | 
    
         
             
                        pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
         
     | 
| 
      
 896 
     | 
    
         
            +
             
     | 
| 
       908 
897 
     | 
    
         
             
                    return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
         
     | 
| 
       909 
898 
     | 
    
         | 
| 
       910 
899 
     | 
    
         
             
                def get_mlm_mcp_genes(
         
     | 
| 
         @@ -921,10 +910,8 @@ class Dialogue: 
     | 
|
| 
       921 
910 
     | 
    
         
             
                        celltype: Cell type of interest.
         
     | 
| 
       922 
911 
     | 
    
         
             
                        results: dl.MultilevelModeling result object.
         
     | 
| 
       923 
912 
     | 
    
         
             
                        MCP: MCP key of the result object.
         
     | 
| 
       924 
     | 
    
         
            -
                         
     | 
| 
       925 
     | 
    
         
            -
                                    Defaults to 0.70.
         
     | 
| 
      
 913 
     | 
    
         
            +
                        threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
         
     | 
| 
       926 
914 
     | 
    
         
             
                        focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
         
     | 
| 
       927 
     | 
    
         
            -
                                         Defaults to None.
         
     | 
| 
       928 
915 
     | 
    
         | 
| 
       929 
916 
     | 
    
         
             
                    Returns:
         
     | 
| 
       930 
917 
     | 
    
         
             
                        Dict with keys 'up_genes' and 'down_genes' and values of lists of genes
         
     | 
| 
         @@ -945,7 +932,6 @@ class Dialogue: 
     | 
|
| 
       945 
932 
     | 
    
         
             
                    # REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
         
     | 
| 
       946 
933 
     | 
    
         
             
                    if MCP.startswith("mcp_"):
         
     | 
| 
       947 
934 
     | 
    
         
             
                        MCP = MCP.replace("mcp_", "MCP")
         
     | 
| 
       948 
     | 
    
         
            -
                        # convert from MCPx to MCPx+1
         
     | 
| 
       949 
935 
     | 
    
         
             
                        MCP = "MCP" + str(int(MCP[3:]) - 1)
         
     | 
| 
       950 
936 
     | 
    
         | 
| 
       951 
937 
     | 
    
         
             
                    # Extract all comparison keys from the results object
         
     | 
| 
         @@ -1004,27 +990,24 @@ class Dialogue: 
     | 
|
| 
       1004 
990 
     | 
    
         
             
                    Args:
         
     | 
| 
       1005 
991 
     | 
    
         
             
                        ct_subs: Dialogue output ct_subs dictionary
         
     | 
| 
       1006 
992 
     | 
    
         
             
                        mcp: The name of the marker gene expression column.
         
     | 
| 
       1007 
     | 
    
         
            -
                             Defaults to "mcp_0".
         
     | 
| 
       1008 
993 
     | 
    
         
             
                        fraction: Fraction of extreme cells to consider for gene ranking.
         
     | 
| 
       1009 
994 
     | 
    
         
             
                                  Should be between 0 and 1.
         
     | 
| 
       1010 
     | 
    
         
            -
                                  Defaults to 0.1.
         
     | 
| 
       1011 
995 
     | 
    
         | 
| 
       1012 
996 
     | 
    
         
             
                    Returns:
         
     | 
| 
       1013 
997 
     | 
    
         
             
                        Dictionary where keys are subpopulation names and values are Anndata
         
     | 
| 
       1014 
998 
     | 
    
         
             
                        objects containing the results of gene ranking analysis.
         
     | 
| 
       1015 
999 
     | 
    
         | 
| 
       1016 
1000 
     | 
    
         
             
                    Examples:
         
     | 
| 
       1017 
     | 
    
         
            -
                        ct_subs = {
         
     | 
| 
       1018 
     | 
    
         
            -
                        "subpop1": anndata_obj1,
         
     | 
| 
       1019 
     | 
    
         
            -
                        "subpop2": anndata_obj2,
         
     | 
| 
       1020 
     | 
    
         
            -
                        # ... more subpopulations ...
         
     | 
| 
       1021 
     | 
    
         
            -
                        }
         
     | 
| 
       1022 
     | 
    
         
            -
                        genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
         
     | 
| 
      
 1001 
     | 
    
         
            +
                        >>> ct_subs = {
         
     | 
| 
      
 1002 
     | 
    
         
            +
                        ...     "subpop1": anndata_obj1,
         
     | 
| 
      
 1003 
     | 
    
         
            +
                        ...     "subpop2": anndata_obj2,
         
     | 
| 
      
 1004 
     | 
    
         
            +
                        ...     # ... more subpopulations ...
         
     | 
| 
      
 1005 
     | 
    
         
            +
                        ... }
         
     | 
| 
      
 1006 
     | 
    
         
            +
                        >>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
         
     | 
| 
       1023 
1007 
     | 
    
         
             
                    """
         
     | 
| 
       1024 
1008 
     | 
    
         
             
                    genes = {}
         
     | 
| 
       1025 
1009 
     | 
    
         
             
                    for ct in ct_subs.keys():
         
     | 
| 
       1026 
1010 
     | 
    
         
             
                        mini = ct_subs[ct]
         
     | 
| 
       1027 
     | 
    
         
            -
                        mini.obs[mcp]
         
     | 
| 
       1028 
1011 
     | 
    
         
             
                        mini.obs["extrema"] = pd.qcut(
         
     | 
| 
       1029 
1012 
     | 
    
         
             
                            mini.obs[mcp],
         
     | 
| 
       1030 
1013 
     | 
    
         
             
                            [0, 0 + fraction, 1 - fraction, 1.0],
         
     | 
| 
         @@ -1034,6 +1017,7 @@ class Dialogue: 
     | 
|
| 
       1034 
1017 
     | 
    
         
             
                            mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
         
     | 
| 
       1035 
1018 
     | 
    
         
             
                        )
         
     | 
| 
       1036 
1019 
     | 
    
         
             
                        genes[ct] = mini  # .uns['rank_genes_groups']
         
     | 
| 
      
 1020 
     | 
    
         
            +
             
     | 
| 
       1037 
1021 
     | 
    
         
             
                    return genes
         
     | 
| 
       1038 
1022 
     | 
    
         | 
| 
       1039 
1023 
     | 
    
         
             
                def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
         
     | 
| 
         @@ -1046,7 +1030,7 @@ class Dialogue: 
     | 
|
| 
       1046 
1030 
     | 
    
         
             
                    Args:
         
     | 
| 
       1047 
1031 
     | 
    
         
             
                        ct_subs: Dialogue output ct_subs dictionary
         
     | 
| 
       1048 
1032 
     | 
    
         
             
                        fraction: Fraction of extreme cells to consider for gene ranking.
         
     | 
| 
       1049 
     | 
    
         
            -
                                  Should be between 0 and 1. 
     | 
| 
      
 1033 
     | 
    
         
            +
                                  Should be between 0 and 1.
         
     | 
| 
       1050 
1034 
     | 
    
         | 
| 
       1051 
1035 
     | 
    
         
             
                    Returns:
         
     | 
| 
       1052 
1036 
     | 
    
         
             
                        Nested dictionary where keys of the first level are MCPs (of the form "mcp_0" etc)
         
     | 
| 
         @@ -1064,7 +1048,7 @@ class Dialogue: 
     | 
|
| 
       1064 
1048 
     | 
    
         
             
                        >>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
         
     | 
| 
       1065 
1049 
     | 
    
         
             
                    """
         
     | 
| 
       1066 
1050 
     | 
    
         
             
                    rank_dfs: dict[str, dict[Any, Any]] = {}
         
     | 
| 
       1067 
     | 
    
         
            -
                     
     | 
| 
      
 1051 
     | 
    
         
            +
                    ct_sub = next(iter(ct_subs.values()))
         
     | 
| 
       1068 
1052 
     | 
    
         
             
                    mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
         
     | 
| 
       1069 
1053 
     | 
    
         | 
| 
       1070 
1054 
     | 
    
         
             
                    for mcp in mcps:
         
     | 
| 
         @@ -1072,4 +1056,123 @@ class Dialogue: 
     | 
|
| 
       1072 
1056 
     | 
    
         
             
                        ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
         
     | 
| 
       1073 
1057 
     | 
    
         
             
                        for celltype in ct_ranked.keys():
         
     | 
| 
       1074 
1058 
     | 
    
         
             
                            rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
         
     | 
| 
      
 1059 
     | 
    
         
            +
             
     | 
| 
       1075 
1060 
     | 
    
         
             
                    return rank_dfs
         
     | 
| 
      
 1061 
     | 
    
         
            +
             
     | 
| 
      
 1062 
     | 
    
         
            +
                def plot_split_violins(
                    self,
                    adata: AnnData,
                    split_key: str,
                    celltype_key: str,
                    split_which: tuple[str, str] = None,
                    mcp: str = "mcp_0",
                    return_fig: bool | None = None,
                    ax: Axes | None = None,
                    save: bool | str | None = None,
                    show: bool | None = None,
                ) -> Axes | Figure | None:
                    """Plots split violin plots for a given MCP and split variable.

                    Any cells with a value for split_key not in split_which are removed from the plot.

                    Args:
                        adata: Annotated data object.
                        split_key: Variable in adata.obs used to split the data.
                        celltype_key: Key for cell type annotations.
                        split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
                        mcp: Key for MCP data.
                        return_fig: If truthy, return the figure that holds the plot.
                        ax: Existing axes to draw into. If None, seaborn draws onto the current axes.
                        save: If truthy, passed to ``savefig`` as the output target.
                        show: If truthy, call ``plt.show()`` after plotting.

                    Returns:
                        The figure holding the plot if `return_fig` is set. Otherwise the
                        :class:`~matplotlib.axes.Axes` object if neither `show` nor `save` is set,
                        and None in all remaining cases.

                    Examples:
                        >>> import pertpy as pt
                        >>> import scanpy as sc
                        >>> adata = pt.dt.dialogue_example()
                        >>> sc.pp.pca(adata)
                        >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
                            n_counts_key = "nCount_RNA", n_mpcs = 3)
                        >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
                        >>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')

                    Preview:
                        .. image:: /_static/docstring_previews/dialogue_violin.png
                    """
                    df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
                    if split_which is None:
                        split_which = df[split_key].unique()
                    # Copy the filtered frame so the column assignment below writes to an
                    # independent object rather than to a slice of the unfiltered frame.
                    df = df[df[split_key].isin(split_which)].copy()
                    # Cast first: the `.cat` accessor only exists on categorical columns.
                    df[split_key] = df[split_key].astype("category").cat.remove_unused_categories()

                    # Forward `ax` so that a caller-supplied axes is actually drawn into.
                    ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True, ax=ax)
                    # Use the figure that owns the axes; the pyplot "current figure" may be a
                    # different one when the caller passed their own axes.
                    fig = ax.get_figure()

                    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

                    if save:
                        fig.savefig(save, bbox_inches="tight")
                    if show:
                        plt.show()
                    if return_fig:
                        return fig
                    if not (show or save):
                        return ax
                    return None
         
     | 
| 
      
 1120 
     | 
    
         
            +
             
     | 
| 
      
 1121 
     | 
    
         
            +
                def plot_pairplot(
         
     | 
| 
      
 1122 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 1123 
     | 
    
         
            +
                    adata: AnnData,
         
     | 
| 
      
 1124 
     | 
    
         
            +
                    celltype_key: str,
         
     | 
| 
      
 1125 
     | 
    
         
            +
                    color: str,
         
     | 
| 
      
 1126 
     | 
    
         
            +
                    sample_id: str,
         
     | 
| 
      
 1127 
     | 
    
         
            +
                    mcp: str = "mcp_0",
         
     | 
| 
      
 1128 
     | 
    
         
            +
                    return_fig: bool | None = None,
         
     | 
| 
      
 1129 
     | 
    
         
            +
                    show: bool | None = None,
         
     | 
| 
      
 1130 
     | 
    
         
            +
                    save: bool | str | None = None,
         
     | 
| 
      
 1131 
     | 
    
         
            +
                ) -> PairGrid | Figure | None:
         
     | 
| 
      
 1132 
     | 
    
         
            +
                    """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
         
     | 
| 
      
 1133 
     | 
    
         
            +
             
     | 
| 
      
 1134 
     | 
    
         
            +
                    Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
         
     | 
| 
      
 1135 
     | 
    
         
            +
                    then creates a pairplot to visualize the relationships between these mean MCP values.
         
     | 
| 
      
 1136 
     | 
    
         
            +
             
     | 
| 
      
 1137 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 1138 
     | 
    
         
            +
                        adata: Annotated data object.
         
     | 
| 
      
 1139 
     | 
    
         
            +
                        celltype_key: Key in `adata.obs` containing cell type annotations.
         
     | 
| 
      
 1140 
     | 
    
         
            +
                        color: Key in `adata.obs` for color annotations. This parameter is used as the hue.
         
     | 
| 
      
 1141 
     | 
    
         
            +
                        sample_id: Key in `adata.obs` for the sample annotations.
         
     | 
| 
      
 1142 
     | 
    
         
            +
                        mcp: Key in `adata.obs` for MCP feature values.
         
     | 
| 
      
 1143 
     | 
    
         
            +
             
     | 
| 
      
 1144 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 1145 
     | 
    
         
            +
                        Seaborn Pairgrid object.
         
     | 
| 
      
 1146 
     | 
    
         
            +
             
     | 
| 
      
 1147 
     | 
    
         
            +
                    Examples:
         
     | 
| 
      
 1148 
     | 
    
         
            +
                        >>> import pertpy as pt
         
     | 
| 
      
 1149 
     | 
    
         
            +
                        >>> import scanpy as sc
         
     | 
| 
      
 1150 
     | 
    
         
            +
                        >>> adata = pt.dt.dialogue_example()
         
     | 
| 
      
 1151 
     | 
    
         
            +
                        >>> sc.pp.pca(adata)
         
     | 
| 
      
 1152 
     | 
    
         
            +
                        >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
         
     | 
| 
      
 1153 
     | 
    
         
            +
                            n_counts_key = "nCount_RNA", n_mpcs = 3)
         
     | 
| 
      
 1154 
     | 
    
         
            +
                        >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
         
     | 
| 
      
 1155 
     | 
    
         
            +
                        >>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
         
     | 
| 
      
 1156 
     | 
    
         
            +
             
     | 
| 
      
 1157 
     | 
    
         
            +
                    Preview:
         
     | 
| 
      
 1158 
     | 
    
         
            +
                        .. image:: /_static/docstring_previews/dialogue_pairplot.png
         
     | 
| 
      
 1159 
     | 
    
         
            +
                    """
         
     | 
| 
      
 1160 
     | 
    
         
            +
                    mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
         
     | 
| 
      
 1161 
     | 
    
         
            +
                    mean_mcps = mean_mcps.reset_index()
         
     | 
| 
      
 1162 
     | 
    
         
            +
                    mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
         
     | 
| 
      
 1163 
     | 
    
         
            +
             
     | 
| 
      
 1164 
     | 
    
         
            +
                    aggstats = adata.obs.groupby([sample_id])[color].describe()
         
     | 
| 
      
 1165 
     | 
    
         
            +
                    aggstats = aggstats.loc[list(mcp_pivot.index), :]
         
     | 
| 
      
 1166 
     | 
    
         
            +
                    aggstats[color] = aggstats["top"]
         
     | 
| 
      
 1167 
     | 
    
         
            +
                    mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
         
     | 
| 
      
 1168 
     | 
    
         
            +
                    ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
         
     | 
| 
      
 1169 
     | 
    
         
            +
             
     | 
| 
      
 1170 
     | 
    
         
            +
                    if save:
         
     | 
| 
      
 1171 
     | 
    
         
            +
                        plt.savefig(save, bbox_inches="tight")
         
     | 
| 
      
 1172 
     | 
    
         
            +
                    if show:
         
     | 
| 
      
 1173 
     | 
    
         
            +
                        plt.show()
         
     | 
| 
      
 1174 
     | 
    
         
            +
                    if return_fig:
         
     | 
| 
      
 1175 
     | 
    
         
            +
                        return plt.gcf()
         
     | 
| 
      
 1176 
     | 
    
         
            +
                    if not (show or save):
         
     | 
| 
      
 1177 
     | 
    
         
            +
                        return ax
         
     | 
| 
      
 1178 
     | 
    
         
            +
                    return None
         
     | 
| 
         @@ -0,0 +1,20 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from ._base import ContrastType, LinearModelBase, MethodBase
         
     | 
| 
      
 2 
     | 
    
         
            +
            from ._dge_comparison import DGEEVAL
         
     | 
| 
      
 3 
     | 
    
         
            +
            from ._edger import EdgeR
         
     | 
| 
      
 4 
     | 
    
         
            +
            from ._pydeseq2 import PyDESeq2
         
     | 
| 
      
 5 
     | 
    
         
            +
            from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
         
     | 
| 
      
 6 
     | 
    
         
            +
            from ._statsmodels import Statsmodels
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            # Public names of this subpackage, i.e. what a star-import of it exposes.
            # "DGEEVAL" is listed as well because it is imported above for re-export,
            # exactly like the other names; omitting it hid it from star-imports.
            __all__ = [
                "MethodBase",
                "LinearModelBase",
                "EdgeR",
                "PyDESeq2",
                "Statsmodels",
                "SimpleComparisonBase",
                "WilcoxonTest",
                "TTest",
                "ContrastType",
                "DGEEVAL",
            ]
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
            # Method classes listed as available, in this fixed order.
            # The *Base classes, ContrastType and DGEEVAL are deliberately not part of this list.
            AVAILABLE_METHODS = [
                Statsmodels,
                EdgeR,
                PyDESeq2,
                WilcoxonTest,
                TTest,
            ]
         
     |