pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +2 -1
- pertpy/data/__init__.py +61 -0
- pertpy/data/_dataloader.py +27 -23
- pertpy/data/_datasets.py +58 -0
- pertpy/metadata/__init__.py +2 -0
- pertpy/metadata/_cell_line.py +39 -70
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_drug.py +2 -6
- pertpy/metadata/_look_up.py +38 -51
- pertpy/metadata/_metadata.py +7 -10
- pertpy/metadata/_moa.py +2 -6
- pertpy/plot/__init__.py +0 -5
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +2 -3
- pertpy/tools/__init__.py +42 -4
- pertpy/tools/_augur.py +14 -15
- pertpy/tools/_cinemaot.py +2 -2
- pertpy/tools/_coda/_base_coda.py +118 -142
- pertpy/tools/_coda/_sccoda.py +16 -15
- pertpy/tools/_coda/_tasccoda.py +21 -22
- pertpy/tools/_dialogue.py +18 -23
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +21 -16
- pertpy/tools/_distances/_distances.py +406 -70
- pertpy/tools/_enrichment.py +10 -15
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +76 -53
- pertpy/tools/_mixscape.py +15 -11
- pertpy/tools/_perturbation_space/_clustering.py +5 -2
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
- pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
- pertpy/tools/_perturbation_space/_simple.py +3 -3
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +33 -28
- pertpy/tools/_scgen/_utils.py +2 -2
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -171
- pertpy/plot/_coda.py +0 -601
- pertpy/plot/_guide_rna.py +0 -64
- pertpy/plot/_milopy.py +0 -209
- pertpy/plot/_mixscape.py +0 -355
- pertpy/tools/_differential_gene_expression.py +0 -325
- pertpy-0.7.0.dist-info/RECORD +0 -53
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
    
        pertpy/tools/_enrichment.py
    CHANGED
    
    | @@ -82,18 +82,15 @@ class Enrichment: | |
| 82 82 | 
             
                                 - A dictionary of dictionaries with group categories as keys. Use `nested=True` in this case.
         | 
| 83 83 | 
             
                                 If not provided, ChEMBL-derived drug target sets are used.
         | 
| 84 84 | 
             
                        nested: Indicates if `targets` is a dictionary of dictionaries with group categories as keys.
         | 
| 85 | 
            -
                                Defaults to False.
         | 
| 86 85 | 
             
                        categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
         | 
| 87 86 | 
             
                                    For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
         | 
| 88 87 | 
             
                        method: Method for scoring gene groups. `"mean"` calculates the mean over all genes,
         | 
| 89 88 | 
             
                                while `"seurat"` uses a background profile subtraction approach.
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                        layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None.
         | 
| 89 | 
            +
                        layer: Specifies which `.layers` of AnnData to use for expression values.
         | 
| 92 90 | 
             
                        n_bins: The number of expression bins for the `'seurat'` method.
         | 
| 93 91 | 
             
                        ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
         | 
| 94 92 | 
             
                        key_added: Prefix key that adds the results to `uns`.
         | 
| 95 93 | 
             
                                   Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`.
         | 
| 96 | 
            -
                                   Defaults to `pertpy_enrichment`.
         | 
| 97 94 |  | 
| 98 95 | 
             
                    Returns:
         | 
| 99 96 | 
             
                        An AnnData object with scores.
         | 
| @@ -259,16 +256,15 @@ class Enrichment: | |
| 259 256 | 
             
                               in the original expression space.
         | 
| 260 257 | 
             
                        targets: The gene groups to evaluate, either as a dictionary with names of the
         | 
| 261 258 | 
             
                                 groups as keys and gene lists as values, or a dictionary of dictionaries
         | 
| 262 | 
            -
                                 with names of gene group categories as keys. | 
| 259 | 
            +
                                 with names of gene group categories as keys.
         | 
| 263 260 | 
             
                                 case it uses `d2c.score()` output or loads ChEMBL-derived drug target sets.
         | 
| 264 261 | 
             
                        nested: Indicates if `targets` is a dictionary of dictionaries with group
         | 
| 265 | 
            -
                                categories as keys. | 
| 262 | 
            +
                                categories as keys.
         | 
| 266 263 | 
             
                        categories: Used to subset the gene groups to one or more categories,
         | 
| 267 | 
            -
                                    applicable if `targets=None` or `nested=True`. | 
| 264 | 
            +
                                    applicable if `targets=None` or `nested=True`.
         | 
| 268 265 | 
             
                        absolute: If True, passes the absolute values of scores to GSEA, improving
         | 
| 269 | 
            -
                                  statistical power. | 
| 266 | 
            +
                                  statistical power.
         | 
| 270 267 | 
             
                        key_added: Prefix key that adds the results to `uns`.
         | 
| 271 | 
            -
                                   Defaults to `pertpy_enrichment_gsea`.
         | 
| 272 268 |  | 
| 273 269 | 
             
                    Returns:
         | 
| 274 270 | 
             
                        A dictionary with clusters as keys and data frames of test results sorted on
         | 
| @@ -317,13 +313,12 @@ class Enrichment: | |
| 317 313 | 
             
                        targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
         | 
| 318 314 | 
             
                                 Accepts a dictionary of dictionaries with group categories as keys.
         | 
| 319 315 | 
             
                                 If not provided, ChEMBL-derived or dgbidb drug target sets are used, given by `source`.
         | 
| 320 | 
            -
                        source: Source of drug target sets when `targets=None`, `chembl`, `dgidb` or `pharmgkb`. | 
| 316 | 
            +
                        source: Source of drug target sets when `targets=None`, `chembl`, `dgidb` or `pharmgkb`.
         | 
| 321 317 | 
             
                        categories: To subset the gene groups to specific categories, especially when `targets=None`.
         | 
| 322 318 | 
             
                                        For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
         | 
| 323 | 
            -
                        category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`. | 
| 319 | 
            +
                        category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
         | 
| 324 320 | 
             
                        groupby: dotplot groupby such as clusters or cell types.
         | 
| 325 321 | 
             
                        key: Prefix key of enrichment results in `uns`.
         | 
| 326 | 
            -
                             Defaults to `pertpy_enrichment`.
         | 
| 327 322 | 
             
                        kwargs: Passed to scanpy dotplot.
         | 
| 328 323 |  | 
| 329 324 | 
             
                    Returns:
         | 
| @@ -436,9 +431,9 @@ class Enrichment: | |
| 436 431 | 
             
                    Args:
         | 
| 437 432 | 
             
                        adata: AnnData object to plot.
         | 
| 438 433 | 
             
                        enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
         | 
| 439 | 
            -
                        n: How many top scores to show for each group. | 
| 440 | 
            -
                        key: GSEA results key in `uns`. | 
| 441 | 
            -
                        interactive_plot: Whether to plot interactively or not. | 
| 434 | 
            +
                        n: How many top scores to show for each group.
         | 
| 435 | 
            +
                        key: GSEA results key in `uns`.
         | 
| 436 | 
            +
                        interactive_plot: Whether to plot interactively or not.
         | 
| 442 437 |  | 
| 443 438 | 
             
                    Examples:
         | 
| 444 439 | 
             
                        >>> import pertpy as pt
         | 
    
        pertpy/tools/_kernel_pca.py
    CHANGED
    
    | @@ -31,7 +31,7 @@ def kernel_pca( | |
| 31 31 |  | 
| 32 32 | 
             
                Returns:
         | 
| 33 33 | 
             
                    If `copy=True`, returns the copy of `adata` with kernel pca in `.obsm["X_kpca"]`.
         | 
| 34 | 
            -
                    Otherwise writes kernel pca directly to `.obsm["X_kpca"]` of the provided `adata`.
         | 
| 34 | 
            +
                    Otherwise, writes kernel pca directly to `.obsm["X_kpca"]` of the provided `adata`.
         | 
| 35 35 | 
             
                    If `return_transformer=True`, returns also the fitted `KernelPCA` transformer.
         | 
| 36 36 | 
             
                """
         | 
| 37 37 | 
             
                if copy:
         | 
    
        pertpy/tools/_milo.py
    CHANGED
    
    | @@ -11,22 +11,16 @@ import pandas as pd | |
| 11 11 | 
             
            import scanpy as sc
         | 
| 12 12 | 
             
            import seaborn as sns
         | 
| 13 13 | 
             
            from anndata import AnnData
         | 
| 14 | 
            +
            from lamin_utils import logger
         | 
| 14 15 | 
             
            from mudata import MuData
         | 
| 15 | 
            -
            from rich import print
         | 
| 16 16 |  | 
| 17 17 | 
             
            if TYPE_CHECKING:
         | 
| 18 18 | 
             
                from collections.abc import Sequence
         | 
| 19 19 |  | 
| 20 20 | 
             
                from matplotlib.axes import Axes
         | 
| 21 21 | 
             
                from matplotlib.colors import Colormap
         | 
| 22 | 
            +
                from matplotlib.figure import Figure
         | 
| 22 23 |  | 
| 23 | 
            -
            try:
         | 
| 24 | 
            -
                from rpy2.robjects import conversion, numpy2ri, pandas2ri
         | 
| 25 | 
            -
                from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
         | 
| 26 | 
            -
            except ModuleNotFoundError:
         | 
| 27 | 
            -
                print(
         | 
| 28 | 
            -
                    "[bold yellow]ryp2 is not installed. Install with [green]pip install rpy2 [yellow]to run tools with R support."
         | 
| 29 | 
            -
                )
         | 
| 30 24 | 
             
            from scipy.sparse import csr_matrix
         | 
| 31 25 | 
             
            from sklearn.metrics.pairwise import euclidean_distances
         | 
| 32 26 |  | 
| @@ -35,7 +29,16 @@ class Milo: | |
| 35 29 | 
             
                """Python implementation of Milo."""
         | 
| 36 30 |  | 
| 37 31 | 
             
                def __init__(self):
         | 
| 38 | 
            -
                     | 
| 32 | 
            +
                    try:
         | 
| 33 | 
            +
                        from rpy2.robjects import conversion, numpy2ri, pandas2ri
         | 
| 34 | 
            +
                        from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
         | 
| 35 | 
            +
                    except ModuleNotFoundError:
         | 
| 36 | 
            +
                        raise ImportError("milo requires rpy2 to be installed.") from None
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    try:
         | 
| 39 | 
            +
                        importr("edgeR")
         | 
| 40 | 
            +
                    except ImportError as e:
         | 
| 41 | 
            +
                        raise ImportError("milo requires a valid R installation with edger installed:\n") from e
         | 
| 39 42 |  | 
| 40 43 | 
             
                def load(
         | 
| 41 44 | 
             
                    self,
         | 
| @@ -48,7 +51,7 @@ class Milo: | |
| 48 51 | 
             
                        input: AnnData
         | 
| 49 52 | 
             
                        feature_key: Key to store the cell-level AnnData object in the MuData object
         | 
| 50 53 | 
             
                    Returns:
         | 
| 51 | 
            -
                        MuData: MuData object with original AnnData. | 
| 54 | 
            +
                        MuData: MuData object with original AnnData.
         | 
| 52 55 |  | 
| 53 56 | 
             
                    Examples:
         | 
| 54 57 | 
             
                        >>> import pertpy as pt
         | 
| @@ -80,11 +83,10 @@ class Milo: | |
| 80 83 | 
             
                        neighbors_key: The key in `adata.obsp` or `mdata[feature_key].obsp` to use as KNN graph.
         | 
| 81 84 | 
             
                                       If not specified, `make_nhoods` looks .obsp[‘connectivities’] for connectivities (default storage places for `scanpy.pp.neighbors`).
         | 
| 82 85 | 
             
                                       If specified, it looks at .obsp[.uns[neighbors_key][‘connectivities_key’]] for connectivities.
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                         | 
| 85 | 
            -
                         | 
| 86 | 
            -
                         | 
| 87 | 
            -
                        copy: Determines whether a copy of the `adata` is returned. Defaults to False.
         | 
| 86 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 87 | 
            +
                        prop: Fraction of cells to sample for neighbourhood index search.
         | 
| 88 | 
            +
                        seed: Random seed for cell sampling.
         | 
| 89 | 
            +
                        copy: Determines whether a copy of the `adata` is returned.
         | 
| 88 90 |  | 
| 89 91 | 
             
                    Returns:
         | 
| 90 92 | 
             
                        If `copy=True`, returns the copy of `adata` with the result in `.obs`, `.obsm`, and `.uns`.
         | 
| @@ -128,7 +130,7 @@ class Milo: | |
| 128 130 | 
             
                        try:
         | 
| 129 131 | 
             
                            knn_graph = adata.obsp["connectivities"].copy()
         | 
| 130 132 | 
             
                        except KeyError:
         | 
| 131 | 
            -
                             | 
| 133 | 
            +
                            logger.error('No "connectivities" slot in adata.obsp -- please run scanpy.pp.neighbors(adata) first')
         | 
| 132 134 | 
             
                            raise
         | 
| 133 135 | 
             
                    else:
         | 
| 134 136 | 
             
                        try:
         | 
| @@ -183,6 +185,7 @@ class Milo: | |
| 183 185 | 
             
                    dist_mat = knn_dists[nhood_ixs, :]
         | 
| 184 186 | 
             
                    k_distances = dist_mat.max(1).toarray().ravel()
         | 
| 185 187 | 
             
                    adata.obs["nhood_kth_distance"] = 0
         | 
| 188 | 
            +
                    adata.obs["nhood_kth_distance"] = adata.obs["nhood_kth_distance"].astype(float)
         | 
| 186 189 | 
             
                    adata.obs.loc[adata.obs["nhood_ixs_refined"] == 1, "nhood_kth_distance"] = k_distances
         | 
| 187 190 |  | 
| 188 191 | 
             
                    if copy:
         | 
| @@ -199,7 +202,7 @@ class Milo: | |
| 199 202 | 
             
                    Args:
         | 
| 200 203 | 
             
                        data: AnnData object with neighbourhoods defined in `obsm['nhoods']` or MuData object with a modality with neighbourhoods defined in `obsm['nhoods']`
         | 
| 201 204 | 
             
                        sample_col: Column in adata.obs that contains sample information
         | 
| 202 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 205 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 203 206 |  | 
| 204 207 | 
             
                    Returns:
         | 
| 205 208 | 
             
                        MuData object storing the original (i.e. rna) AnnData in `mudata[feature_key]`
         | 
| @@ -230,7 +233,7 @@ class Milo: | |
| 230 233 | 
             
                        try:
         | 
| 231 234 | 
             
                            nhoods = adata.obsm["nhoods"]
         | 
| 232 235 | 
             
                        except KeyError:
         | 
| 233 | 
            -
                             | 
| 236 | 
            +
                            logger.error('Cannot find "nhoods" slot in adata.obsm -- please run milopy.make_nhoods(adata)')
         | 
| 234 237 | 
             
                            raise
         | 
| 235 238 | 
             
                    # Make nhood abundance matrix
         | 
| 236 239 | 
             
                    sample_dummies = pd.get_dummies(adata.obs[sample_col])
         | 
| @@ -238,7 +241,7 @@ class Milo: | |
| 238 241 | 
             
                    sample_dummies = csr_matrix(sample_dummies.values)
         | 
| 239 242 | 
             
                    nhood_count_mat = nhoods.T.dot(sample_dummies)
         | 
| 240 243 | 
             
                    sample_obs = pd.DataFrame(index=all_samples)
         | 
| 241 | 
            -
                    sample_adata = AnnData(X=nhood_count_mat.T, obs=sample_obs | 
| 244 | 
            +
                    sample_adata = AnnData(X=nhood_count_mat.T, obs=sample_obs)
         | 
| 242 245 | 
             
                    sample_adata.uns["sample_col"] = sample_col
         | 
| 243 246 | 
             
                    # Save nhood index info
         | 
| 244 247 | 
             
                    sample_adata.var["index_cell"] = adata.obs_names[adata.obs["nhood_ixs_refined"] == 1]
         | 
| @@ -270,10 +273,10 @@ class Milo: | |
| 270 273 | 
             
                        design: Formula for the test, following glm syntax from R (e.g. '~ condition').
         | 
| 271 274 | 
             
                                Terms should be columns in `milo_mdata[feature_key].obs`.
         | 
| 272 275 | 
             
                        model_contrasts: A string vector that defines the contrasts used to perform DA testing, following glm syntax from R (e.g. "conditionDisease - conditionControl").
         | 
| 273 | 
            -
                                         If no contrast is specified (default), then the last categorical level in condition of interest is used as the test group. | 
| 274 | 
            -
                        subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test. | 
| 275 | 
            -
                        add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default. | 
| 276 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 276 | 
            +
                                         If no contrast is specified (default), then the last categorical level in condition of interest is used as the test group.
         | 
| 277 | 
            +
                        subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test.
         | 
| 278 | 
            +
                        add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default.
         | 
| 279 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 277 280 | 
             
                        solver: The solver to fit the model to. One of "edger" (requires R, rpy2 and edgeR to be installed) or "batchglm"
         | 
| 278 281 |  | 
| 279 282 | 
             
                    Returns:
         | 
| @@ -297,8 +300,8 @@ class Milo: | |
| 297 300 | 
             
                    try:
         | 
| 298 301 | 
             
                        sample_adata = mdata["milo"]
         | 
| 299 302 | 
             
                    except KeyError:
         | 
| 300 | 
            -
                         | 
| 301 | 
            -
                            " | 
| 303 | 
            +
                        logger.error(
         | 
| 304 | 
            +
                            "milo_mdata should be a MuData object with two slots:"
         | 
| 302 305 | 
             
                            " feature_key and 'milo' - please run milopy.count_nhoods() first"
         | 
| 303 306 | 
             
                        )
         | 
| 304 307 | 
             
                        raise
         | 
| @@ -312,7 +315,7 @@ class Milo: | |
| 312 315 | 
             
                        sample_obs = adata.obs[covariates + [sample_col]].drop_duplicates()
         | 
| 313 316 | 
             
                    except KeyError:
         | 
| 314 317 | 
             
                        missing_cov = [x for x in covariates if x not in sample_adata.obs.columns]
         | 
| 315 | 
            -
                         | 
| 318 | 
            +
                        logger.warning("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
         | 
| 316 319 | 
             
                        raise
         | 
| 317 320 | 
             
                    sample_obs = sample_obs[covariates + [sample_col]]
         | 
| 318 321 | 
             
                    sample_obs.index = sample_obs[sample_col].astype("str")
         | 
| @@ -320,7 +323,7 @@ class Milo: | |
| 320 323 | 
             
                    try:
         | 
| 321 324 | 
             
                        assert sample_obs.loc[sample_adata.obs_names].shape[0] == len(sample_adata.obs_names)
         | 
| 322 325 | 
             
                    except AssertionError:
         | 
| 323 | 
            -
                         | 
| 326 | 
            +
                        logger.warning(
         | 
| 324 327 | 
             
                            f"Values in mdata[{feature_key}].obs[{covariates}] cannot be unambiguously assigned to each sample"
         | 
| 325 328 | 
             
                            f" -- each sample value should match a single covariate value"
         | 
| 326 329 | 
             
                        )
         | 
| @@ -332,7 +335,9 @@ class Milo: | |
| 332 335 | 
             
                        design_df = sample_adata.obs[covariates]
         | 
| 333 336 | 
             
                    except KeyError:
         | 
| 334 337 | 
             
                        missing_cov = [x for x in covariates if x not in sample_adata.obs.columns]
         | 
| 335 | 
            -
                         | 
| 338 | 
            +
                        logger.error(
         | 
| 339 | 
            +
                            'Covariates {c} are not columns in adata.uns["sample_adata"].obs'.format(c=" ".join(missing_cov))
         | 
| 340 | 
            +
                        )
         | 
| 336 341 | 
             
                        raise
         | 
| 337 342 | 
             
                    # Get count matrix
         | 
| 338 343 | 
             
                    count_mat = sample_adata.X.T.toarray()
         | 
| @@ -376,6 +381,8 @@ class Milo: | |
| 376 381 | 
             
                                return(colnames(m))
         | 
| 377 382 | 
             
                            }
         | 
| 378 383 | 
             
                            """
         | 
| 384 | 
            +
                            from rpy2.robjects.packages import STAP
         | 
| 385 | 
            +
             | 
| 379 386 | 
             
                            get_model_cols = STAP(r_str, "get_model_cols")
         | 
| 380 387 | 
             
                            model_mat_cols = get_model_cols.get_model_cols(design_df, design)
         | 
| 381 388 | 
             
                            model_df = pd.DataFrame(model)
         | 
| @@ -383,13 +390,16 @@ class Milo: | |
| 383 390 | 
             
                            try:
         | 
| 384 391 | 
             
                                mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
         | 
| 385 392 | 
             
                            except ValueError:
         | 
| 386 | 
            -
                                 | 
| 393 | 
            +
                                logger.error("Model contrasts must be in the form 'A-B' or 'A+B'")
         | 
| 387 394 | 
             
                                raise
         | 
| 388 395 | 
             
                            res = base.as_data_frame(
         | 
| 389 396 | 
             
                                edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
         | 
| 390 397 | 
             
                            )
         | 
| 391 398 | 
             
                        else:
         | 
| 392 399 | 
             
                            res = base.as_data_frame(edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf))
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                        from rpy2.robjects import conversion
         | 
| 402 | 
            +
             | 
| 393 403 | 
             
                        res = conversion.rpy2py(res)
         | 
| 394 404 | 
             
                        if not isinstance(res, pd.DataFrame):
         | 
| 395 405 | 
             
                            res = pd.DataFrame(res)
         | 
| @@ -414,7 +424,7 @@ class Milo: | |
| 414 424 | 
             
                    Args:
         | 
| 415 425 | 
             
                        mdata: MuData object
         | 
| 416 426 | 
             
                        anno_col: Column in adata.obs containing the cell annotations to use for nhood labelling
         | 
| 417 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 427 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 418 428 |  | 
| 419 429 | 
             
                    Returns:
         | 
| 420 430 | 
             
                        None. Adds in place:
         | 
| @@ -437,7 +447,7 @@ class Milo: | |
| 437 447 | 
             
                    try:
         | 
| 438 448 | 
             
                        sample_adata = mdata["milo"]
         | 
| 439 449 | 
             
                    except KeyError:
         | 
| 440 | 
            -
                         | 
| 450 | 
            +
                        logger.error(
         | 
| 441 451 | 
             
                            "milo_mdata should be a MuData object with two slots: feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
         | 
| 442 452 | 
             
                        )
         | 
| 443 453 | 
             
                        raise
         | 
| @@ -468,7 +478,7 @@ class Milo: | |
| 468 478 | 
             
                    Args:
         | 
| 469 479 | 
             
                        mdata: MuData object
         | 
| 470 480 | 
             
                        anno_col: Column in adata.obs containing the cell annotations to use for nhood labelling
         | 
| 471 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 481 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 472 482 |  | 
| 473 483 | 
             
                    Returns:
         | 
| 474 484 | 
             
                        None. Adds in place:
         | 
| @@ -509,7 +519,7 @@ class Milo: | |
| 509 519 | 
             
                    Args:
         | 
| 510 520 | 
             
                        mdata: MuData object
         | 
| 511 521 | 
             
                        new_covariates: columns in `milo_mdata[feature_key].obs` to add to `milo_mdata['milo'].obs`.
         | 
| 512 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 522 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 513 523 |  | 
| 514 524 | 
             
                    Returns:
         | 
| 515 525 | 
             
                        None, adds columns to `milo_mdata['milo']` in place
         | 
| @@ -528,7 +538,7 @@ class Milo: | |
| 528 538 | 
             
                    try:
         | 
| 529 539 | 
             
                        sample_adata = mdata["milo"]
         | 
| 530 540 | 
             
                    except KeyError:
         | 
| 531 | 
            -
                         | 
| 541 | 
            +
                        logger.error(
         | 
| 532 542 | 
             
                            "milo_mdata should be a MuData object with two slots: feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
         | 
| 533 543 | 
             
                        )
         | 
| 534 544 | 
             
                        raise
         | 
| @@ -542,14 +552,14 @@ class Milo: | |
| 542 552 | 
             
                        sample_obs = adata.obs[covariates + [sample_col]].drop_duplicates()
         | 
| 543 553 | 
             
                    except KeyError:
         | 
| 544 554 | 
             
                        missing_cov = [covar for covar in covariates if covar not in sample_adata.obs.columns]
         | 
| 545 | 
            -
                         | 
| 555 | 
            +
                        logger.error("Covariates {c} are not columns in adata.obs".format(c=" ".join(missing_cov)))
         | 
| 546 556 | 
             
                        raise
         | 
| 547 557 | 
             
                    sample_obs = sample_obs[covariates + [sample_col]].astype("str")
         | 
| 548 558 | 
             
                    sample_obs.index = sample_obs[sample_col]
         | 
| 549 559 | 
             
                    try:
         | 
| 550 560 | 
             
                        assert sample_obs.loc[sample_adata.obs_names].shape[0] == len(sample_adata.obs_names)
         | 
| 551 561 | 
             
                    except ValueError:
         | 
| 552 | 
            -
                         | 
| 562 | 
            +
                        logger.error(
         | 
| 553 563 | 
             
                            "Covariates cannot be unambiguously assigned to each sample -- each sample value should match a single covariate value"
         | 
| 554 564 | 
             
                        )
         | 
| 555 565 | 
             
                        raise
         | 
| @@ -560,8 +570,8 @@ class Milo: | |
| 560 570 |  | 
| 561 571 | 
             
                    Args:
         | 
| 562 572 | 
             
                        mdata: MuData object
         | 
| 563 | 
            -
                        basis: Name of the obsm basis to use for layout of neighbourhoods (key in `adata.obsm`). | 
| 564 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 573 | 
            +
                        basis: Name of the obsm basis to use for layout of neighbourhoods (key in `adata.obsm`).
         | 
| 574 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 565 575 |  | 
| 566 576 | 
             
                    Returns:
         | 
| 567 577 | 
             
                        - `milo_mdata['milo'].varp['nhood_connectivities']`: graph of overlap between neighbourhoods (i.e. no of shared cells)
         | 
| @@ -593,13 +603,13 @@ class Milo: | |
| 593 603 | 
             
                        "distances_key": "",
         | 
| 594 604 | 
             
                    }
         | 
| 595 605 |  | 
| 596 | 
            -
                def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_key: str | None = "rna"):
         | 
| 606 | 
            +
                def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_key: str | None = "rna") -> None:
         | 
| 597 607 | 
             
                    """Calculates the mean expression in neighbourhoods of each feature.
         | 
| 598 608 |  | 
| 599 609 | 
             
                    Args:
         | 
| 600 610 | 
             
                        mdata: MuData object
         | 
| 601 | 
            -
                        layer: If provided, use `milo_mdata[feature_key][layer]` as expression matrix instead of `milo_mdata[feature_key].X`. | 
| 602 | 
            -
                        feature_key: If input data is MuData, specify key to cell-level AnnData object. | 
| 611 | 
            +
                        layer: If provided, use `milo_mdata[feature_key][layer]` as expression matrix instead of `milo_mdata[feature_key].X`.
         | 
| 612 | 
            +
                        feature_key: If input data is MuData, specify key to cell-level AnnData object.
         | 
| 603 613 |  | 
| 604 614 | 
             
                    Returns:
         | 
| 605 615 | 
             
                        Updates adata in place to store the matrix of average expression in each neighbourhood in `milo_mdata['milo'].varm['expr']`
         | 
| @@ -618,7 +628,7 @@ class Milo: | |
| 618 628 | 
             
                    try:
         | 
| 619 629 | 
             
                        sample_adata = mdata["milo"]
         | 
| 620 630 | 
             
                    except KeyError:
         | 
| 621 | 
            -
                         | 
| 631 | 
            +
                        logger.error(
         | 
| 622 632 | 
             
                            "milo_mdata should be a MuData object with two slots:"
         | 
| 623 633 | 
             
                            " feature_key and 'milo' - please run milopy.count_nhoods(adata) first"
         | 
| 624 634 | 
             
                        )
         | 
| @@ -642,6 +652,9 @@ class Milo: | |
| 642 652 | 
             
                    self,
         | 
| 643 653 | 
             
                ):
         | 
| 644 654 | 
             
                    """Set up rpy2 to run edgeR"""
         | 
| 655 | 
            +
                    from rpy2.robjects import numpy2ri, pandas2ri
         | 
| 656 | 
            +
                    from rpy2.robjects.packages import importr
         | 
| 657 | 
            +
             | 
| 645 658 | 
             
                    numpy2ri.activate()
         | 
| 646 659 | 
             
                    pandas2ri.activate()
         | 
| 647 660 | 
             
                    edgeR = self._try_import_bioc_library("edgeR")
         | 
| @@ -660,11 +673,13 @@ class Milo: | |
| 660 673 | 
             
                    Args:
         | 
| 661 674 | 
             
                        name (str): R packages name
         | 
| 662 675 | 
             
                    """
         | 
| 676 | 
            +
                    from rpy2.robjects.packages import PackageNotInstalledError, importr
         | 
| 677 | 
            +
             | 
| 663 678 | 
             
                    try:
         | 
| 664 679 | 
             
                        _r_lib = importr(name)
         | 
| 665 680 | 
             
                        return _r_lib
         | 
| 666 681 | 
             
                    except PackageNotInstalledError:
         | 
| 667 | 
            -
                         | 
| 682 | 
            +
                        logger.error(f"Install Bioconductor library `{name!r}` first as `BiocManager::install({name!r}).`")
         | 
| 668 683 | 
             
                        raise
         | 
| 669 684 |  | 
| 670 685 | 
             
                def _graph_spatial_fdr(
         | 
| @@ -678,7 +693,7 @@ class Milo: | |
| 678 693 |  | 
| 679 694 | 
             
                    Args:
         | 
| 680 695 | 
             
                        sample_adata: Sample-level AnnData.
         | 
| 681 | 
            -
                        neighbors_key: The key in `adata.obsp` to use as KNN graph. | 
| 696 | 
            +
                        neighbors_key: The key in `adata.obsp` to use as KNN graph.
         | 
| 682 697 | 
             
                    """
         | 
| 683 698 | 
             
                    # use 1/connectivity as the weighting for the weighted BH adjustment from Cydar
         | 
| 684 699 | 
             
                    w = 1 / sample_adata.var["kth_distance"]
         | 
| @@ -718,10 +733,10 @@ class Milo: | |
| 718 733 | 
             
                    Args:
         | 
| 719 734 | 
             
                        mdata: MuData object
         | 
| 720 735 | 
             
                        alpha: Significance threshold. (default: 0.1)
         | 
| 721 | 
            -
                        min_logFC: Minimum absolute log-Fold Change to show results. If is 0, show all significant neighbourhoods. | 
| 736 | 
            +
                        min_logFC: Minimum absolute log-Fold Change to show results. If is 0, show all significant neighbourhoods.
         | 
| 722 737 | 
             
                        min_size: Minimum size of nodes in visualization. (default: 10)
         | 
| 723 | 
            -
                        plot_edges: If edges for neighbourhood overlaps whould be plotted. | 
| 724 | 
            -
                        title: Plot title. | 
| 738 | 
            +
                        plot_edges: If edges for neighbourhood overlaps whould be plotted.
         | 
| 739 | 
            +
                        title: Plot title.
         | 
| 725 740 | 
             
                        show: Show the plot, do not return axis.
         | 
| 726 741 | 
             
                        save: If `True` or a `str`, save the figure. A string is appended to the default filename.
         | 
| 727 742 | 
             
                              Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
         | 
| @@ -807,7 +822,7 @@ class Milo: | |
| 807 822 | 
             
                    Args:
         | 
| 808 823 | 
             
                        mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']`
         | 
| 809 824 | 
             
                        ix: index of neighbourhood to visualize
         | 
| 810 | 
            -
                        basis: Embedding to use for visualization. | 
| 825 | 
            +
                        basis: Embedding to use for visualization.
         | 
| 811 826 | 
             
                        show: Show the plot, do not return axis.
         | 
| 812 827 | 
             
                        save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
         | 
| 813 828 | 
             
                        **kwargs: Additional arguments to `scanpy.pl.embedding`.
         | 
| @@ -853,14 +868,14 @@ class Milo: | |
| 853 868 | 
             
                    return_fig: bool | None = None,
         | 
| 854 869 | 
             
                    save: bool | str | None = None,
         | 
| 855 870 | 
             
                    show: bool | None = None,
         | 
| 856 | 
            -
                ) -> None:
         | 
| 871 | 
            +
                ) -> Figure | Axes | None:
         | 
| 857 872 | 
             
                    """Plot beeswarm plot of logFC against nhood labels
         | 
| 858 873 |  | 
| 859 874 | 
             
                    Args:
         | 
| 860 875 | 
             
                        mdata: MuData object
         | 
| 861 876 | 
             
                        anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
         | 
| 862 877 | 
             
                        alpha: Significance threshold. (default: 0.1)
         | 
| 863 | 
            -
                        subset_nhoods: List of nhoods to plot. If None, plot all nhoods. | 
| 878 | 
            +
                        subset_nhoods: List of nhoods to plot. If None, plot all nhoods.
         | 
| 864 879 | 
             
                        palette: Name of Seaborn color palette for violinplots.
         | 
| 865 880 | 
             
                                 Defaults to pre-defined category colors for violinplots.
         | 
| 866 881 |  | 
| @@ -960,13 +975,17 @@ class Milo: | |
| 960 975 |  | 
| 961 976 | 
             
                    if save:
         | 
| 962 977 | 
             
                        plt.savefig(save, bbox_inches="tight")
         | 
| 978 | 
            +
                        return None
         | 
| 963 979 | 
             
                    if show:
         | 
| 964 980 | 
             
                        plt.show()
         | 
| 981 | 
            +
                        return None
         | 
| 965 982 | 
             
                    if return_fig:
         | 
| 966 983 | 
             
                        return plt.gcf()
         | 
| 967 984 | 
             
                    if (not show and not save) or (show is None and save is None):
         | 
| 968 985 | 
             
                        return plt.gca()
         | 
| 969 986 |  | 
| 987 | 
            +
                    return None
         | 
| 988 | 
            +
             | 
| 970 989 | 
             
                def plot_nhood_counts_by_cond(
         | 
| 971 990 | 
             
                    self,
         | 
| 972 991 | 
             
                    mdata: MuData,
         | 
| @@ -976,14 +995,14 @@ class Milo: | |
| 976 995 | 
             
                    return_fig: bool | None = None,
         | 
| 977 996 | 
             
                    save: bool | str | None = None,
         | 
| 978 997 | 
             
                    show: bool | None = None,
         | 
| 979 | 
            -
                ) -> None:
         | 
| 998 | 
            +
                ) -> Figure | Axes | None:
         | 
| 980 999 | 
             
                    """Plot boxplot of cell numbers vs condition of interest.
         | 
| 981 1000 |  | 
| 982 1001 | 
             
                    Args:
         | 
| 983 1002 | 
             
                        mdata: MuData object storing cell level and nhood level information
         | 
| 984 1003 | 
             
                        test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot)
         | 
| 985 | 
            -
                        subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods. | 
| 986 | 
            -
                        log_counts: Whether to plot log1p of cell counts. | 
| 1004 | 
            +
                        subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
         | 
| 1005 | 
            +
                        log_counts: Whether to plot log1p of cell counts.
         | 
| 987 1006 | 
             
                    """
         | 
| 988 1007 | 
             
                    try:
         | 
| 989 1008 | 
             
                        nhood_adata = mdata["milo"].T.copy()
         | 
| @@ -1014,9 +1033,13 @@ class Milo: | |
| 1014 1033 |  | 
| 1015 1034 | 
             
                    if save:
         | 
| 1016 1035 | 
             
                        plt.savefig(save, bbox_inches="tight")
         | 
| 1036 | 
            +
                        return None
         | 
| 1017 1037 | 
             
                    if show:
         | 
| 1018 1038 | 
             
                        plt.show()
         | 
| 1039 | 
            +
                        return None
         | 
| 1019 1040 | 
             
                    if return_fig:
         | 
| 1020 1041 | 
             
                        return plt.gcf()
         | 
| 1021 1042 | 
             
                    if not (show or save):
         | 
| 1022 1043 | 
             
                        return plt.gca()
         | 
| 1044 | 
            +
             | 
| 1045 | 
            +
                    return None
         | 
    
        pertpy/tools/_mixscape.py
    CHANGED
    
    | @@ -178,7 +178,7 @@ class Mixscape: | |
| 178 178 | 
             
                        split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
         | 
| 179 179 | 
             
                                the perturbation signature for every replicate separately.
         | 
| 180 180 | 
             
                        pval_cutoff: P-value cut-off for selection of significantly DE genes.
         | 
| 181 | 
            -
                        perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. | 
| 181 | 
            +
                        perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
         | 
| 182 182 | 
             
                        copy: Determines whether a copy of the `adata` is returned.
         | 
| 183 183 |  | 
| 184 184 | 
             
                    Returns:
         | 
| @@ -227,7 +227,7 @@ class Mixscape: | |
| 227 227 | 
             
                            X = adata_comp.layers["X_pert"]
         | 
| 228 228 | 
             
                        except KeyError:
         | 
| 229 229 | 
             
                            raise KeyError(
         | 
| 230 | 
            -
                                "No 'X_pert' found in .layers! Please run  | 
| 230 | 
            +
                                "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
         | 
| 231 231 | 
             
                            ) from None
         | 
| 232 232 | 
             
                    # initialize return variables
         | 
| 233 233 | 
             
                    adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
         | 
| @@ -315,7 +315,9 @@ class Mixscape: | |
| 315 315 | 
             
                                )
         | 
| 316 316 |  | 
| 317 317 | 
             
                            adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
         | 
| 318 | 
            -
                            adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =  | 
| 318 | 
            +
                            adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
         | 
| 319 | 
            +
                                post_prob
         | 
| 320 | 
            +
                            ).astype("int64")
         | 
| 319 321 | 
             
                    adata.uns["mixscape"] = gv_list
         | 
| 320 322 |  | 
| 321 323 | 
             
                    if copy:
         | 
| @@ -344,15 +346,13 @@ class Mixscape: | |
| 344 346 | 
             
                        control: Control category from the `pert_key` column.
         | 
| 345 347 | 
             
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
         | 
| 346 348 | 
             
                        layer: Key from `adata.layers` whose value will be used to perform tests on.
         | 
| 347 | 
            -
                        control: Control category from the `pert_key` column. | 
| 348 | 
            -
                        n_comps: Number of principal components to use. | 
| 349 | 
            +
                        control: Control category from the `pert_key` column.
         | 
| 350 | 
            +
                        n_comps: Number of principal components to use.
         | 
| 349 351 | 
             
                        min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
         | 
| 350 352 | 
             
                        logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
         | 
| 351 | 
            -
                                         Defaults to 0.25.
         | 
| 352 353 | 
             
                        split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
         | 
| 353 354 | 
             
                        pval_cutoff: P-value cut-off for selection of significantly DE genes.
         | 
| 354 355 | 
             
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
         | 
| 355 | 
            -
                                           Defaults to KO.
         | 
| 356 356 | 
             
                        copy: Determines whether a copy of the `adata` is returned.
         | 
| 357 357 |  | 
| 358 358 | 
             
                    Returns:
         | 
| @@ -461,7 +461,13 @@ class Mixscape: | |
| 461 461 | 
             
                        adata_split = adata[split_mask].copy()
         | 
| 462 462 | 
             
                        # find top DE genes between cells with targeting and non-targeting gRNAs
         | 
| 463 463 | 
             
                        sc.tl.rank_genes_groups(
         | 
| 464 | 
            -
                            adata_split, | 
| 464 | 
            +
                            adata_split,
         | 
| 465 | 
            +
                            layer=layer,
         | 
| 466 | 
            +
                            groupby=labels,
         | 
| 467 | 
            +
                            groups=genes,
         | 
| 468 | 
            +
                            reference=control,
         | 
| 469 | 
            +
                            method="t-test",
         | 
| 470 | 
            +
                            use_raw=False,
         | 
| 465 471 | 
             
                        )
         | 
| 466 472 | 
             
                        # get DE genes for each gene
         | 
| 467 473 | 
             
                        for gene in genes:
         | 
| @@ -704,7 +710,6 @@ class Mixscape: | |
| 704 710 | 
             
                        before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
         | 
| 705 711 | 
             
                                         Default is set to NULL and plots cells by original class ID.
         | 
| 706 712 | 
             
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
         | 
| 707 | 
            -
                                           Defaults to `KO`.
         | 
| 708 713 |  | 
| 709 714 | 
             
                    Examples:
         | 
| 710 715 | 
             
                        Visualizing the perturbation scores for the cells in a dataset:
         | 
| @@ -881,7 +886,7 @@ class Mixscape: | |
| 881 886 | 
             
                        keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
         | 
| 882 887 | 
             
                        groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
         | 
| 883 888 | 
             
                        log: Plot on logarithmic axis.
         | 
| 884 | 
            -
                        use_raw: Whether to use `raw` attribute of `adata`. | 
| 889 | 
            +
                        use_raw: Whether to use `raw` attribute of `adata`.
         | 
| 885 890 | 
             
                        stripplot: Add a stripplot on top of the violin plot.
         | 
| 886 891 | 
             
                        order: Order in which to show the categories.
         | 
| 887 892 | 
             
                        xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
         | 
| @@ -1075,7 +1080,6 @@ class Mixscape: | |
| 1075 1080 | 
             
                        mixscape_class: The column of `.obs` with the mixscape classification result.
         | 
| 1076 1081 | 
             
                        mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
         | 
| 1077 1082 | 
             
                        perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
         | 
| 1078 | 
            -
                                           Defaults to 'KO'.
         | 
| 1079 1083 | 
             
                        lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
         | 
| 1080 1084 | 
             
                        n_components: The number of dimensions of the embedding.
         | 
| 1081 1085 | 
             
                        show: Show the plot, do not return axis.
         | 
| @@ -7,6 +7,8 @@ from sklearn.metrics import pairwise_distances | |
| 7 7 | 
             
            from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
         | 
| 8 8 |  | 
| 9 9 | 
             
            if TYPE_CHECKING:
         | 
| 10 | 
            +
                from collections.abc import Iterable
         | 
| 11 | 
            +
             | 
| 10 12 | 
             
                from anndata import AnnData
         | 
| 11 13 |  | 
| 12 14 |  | 
| @@ -14,6 +16,7 @@ class ClusteringSpace(PerturbationSpace): | |
| 14 16 | 
             
                """Applies various clustering techniques to an embedding."""
         | 
| 15 17 |  | 
| 16 18 | 
             
                def __init__(self):
         | 
| 19 | 
            +
                    super().__init__()
         | 
| 17 20 | 
             
                    self.X = None
         | 
| 18 21 |  | 
| 19 22 | 
             
                def evaluate_clustering(
         | 
| @@ -21,7 +24,7 @@ class ClusteringSpace(PerturbationSpace): | |
| 21 24 | 
             
                    adata: AnnData,
         | 
| 22 25 | 
             
                    true_label_col: str,
         | 
| 23 26 | 
             
                    cluster_col: str,
         | 
| 24 | 
            -
                    metrics:  | 
| 27 | 
            +
                    metrics: Iterable[str] = None,
         | 
| 25 28 | 
             
                    **kwargs,
         | 
| 26 29 | 
             
                ):
         | 
| 27 30 | 
             
                    """Evaluation of previously computed clustering against ground truth labels.
         | 
| @@ -30,7 +33,7 @@ class ClusteringSpace(PerturbationSpace): | |
| 30 33 | 
             
                        adata: AnnData object that contains the clustered data and the cluster labels.
         | 
| 31 34 | 
             
                        true_label_col: ground truth labels.
         | 
| 32 35 | 
             
                        cluster_col: cluster computed labels.
         | 
| 33 | 
            -
                        metrics: Metrics to compute.  | 
| 36 | 
            +
                        metrics: Metrics to compute. If `None` it defaults to ["nmi", "ari", "asw"].
         | 
| 34 37 | 
             
                        **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
         | 
| 35 38 | 
             
                            For asw, metric, distances, sample_size, and random_state can be passed.
         | 
| 36 39 |  |