pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +2 -1
- pertpy/data/__init__.py +61 -0
- pertpy/data/_dataloader.py +27 -23
- pertpy/data/_datasets.py +58 -0
- pertpy/metadata/__init__.py +2 -0
- pertpy/metadata/_cell_line.py +39 -70
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_drug.py +2 -6
- pertpy/metadata/_look_up.py +38 -51
- pertpy/metadata/_metadata.py +7 -10
- pertpy/metadata/_moa.py +2 -6
- pertpy/plot/__init__.py +0 -5
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +2 -3
- pertpy/tools/__init__.py +42 -4
- pertpy/tools/_augur.py +14 -15
- pertpy/tools/_cinemaot.py +2 -2
- pertpy/tools/_coda/_base_coda.py +118 -142
- pertpy/tools/_coda/_sccoda.py +16 -15
- pertpy/tools/_coda/_tasccoda.py +21 -22
- pertpy/tools/_dialogue.py +18 -23
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +21 -16
- pertpy/tools/_distances/_distances.py +406 -70
- pertpy/tools/_enrichment.py +10 -15
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +76 -53
- pertpy/tools/_mixscape.py +15 -11
- pertpy/tools/_perturbation_space/_clustering.py +5 -2
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
- pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
- pertpy/tools/_perturbation_space/_simple.py +3 -3
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +33 -28
- pertpy/tools/_scgen/_utils.py +2 -2
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -171
- pertpy/plot/_coda.py +0 -601
- pertpy/plot/_guide_rna.py +0 -64
- pertpy/plot/_milopy.py +0 -209
- pertpy/plot/_mixscape.py +0 -355
- pertpy/tools/_differential_gene_expression.py +0 -325
- pertpy-0.7.0.dist-info/RECORD +0 -53
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
| @@ -1,7 +1,8 @@ | |
| 1 1 | 
             
            from __future__ import annotations
         | 
| 2 2 |  | 
| 3 | 
            +
            import multiprocessing
         | 
| 3 4 | 
             
            from abc import ABC, abstractmethod
         | 
| 4 | 
            -
            from typing import TYPE_CHECKING
         | 
| 5 | 
            +
            from typing import TYPE_CHECKING, Literal, NamedTuple
         | 
| 5 6 |  | 
| 6 7 | 
             
            import numba
         | 
| 7 8 | 
             
            import numpy as np
         | 
| @@ -13,18 +14,26 @@ from ott.solvers.linear.sinkhorn import Sinkhorn | |
| 13 14 | 
             
            from pandas import Series
         | 
| 14 15 | 
             
            from rich.progress import track
         | 
| 15 16 | 
             
            from scipy.sparse import issparse
         | 
| 16 | 
            -
            from scipy.spatial.distance import cosine
         | 
| 17 | 
            +
            from scipy.spatial.distance import cosine, mahalanobis
         | 
| 17 18 | 
             
            from scipy.special import gammaln
         | 
| 18 19 | 
             
            from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
         | 
| 19 20 | 
             
            from sklearn.linear_model import LogisticRegression
         | 
| 20 21 | 
             
            from sklearn.metrics import pairwise_distances, r2_score
         | 
| 21 22 | 
             
            from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
         | 
| 23 | 
            +
            from sklearn.neighbors import KernelDensity
         | 
| 22 24 | 
             
            from statsmodels.discrete.discrete_model import NegativeBinomialP
         | 
| 23 25 |  | 
| 24 26 | 
             
            if TYPE_CHECKING:
         | 
| 27 | 
            +
                from collections.abc import Callable
         | 
| 28 | 
            +
             | 
| 25 29 | 
             
                from anndata import AnnData
         | 
| 26 30 |  | 
| 27 31 |  | 
| 32 | 
            +
            class MeanVar(NamedTuple):
         | 
| 33 | 
            +
                mean: float
         | 
| 34 | 
            +
                variance: float
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 28 37 | 
             
            class Distance:
         | 
| 29 38 | 
             
                """Distance class, used to compute distances between groups of cells.
         | 
| 30 39 |  | 
| @@ -80,6 +89,11 @@ class Distance: | |
| 80 89 | 
             
                    Average of the classification probability of the perturbation for a binary classifier.
         | 
| 81 90 | 
             
                - "classifier_cp": classifier class projection
         | 
| 82 91 | 
             
                    Average of the class
         | 
| 92 | 
            +
                - "mean_var_distribution": Distance between the mean-variance distributions of cells from two groups.
         | 
| 93 | 
            +
                   Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
         | 
| 94 | 
            +
                - "mahalanobis": Mahalanobis distance between the means of cells from two groups.
         | 
| 95 | 
            +
                    It is originally used to measure distance between a point and a distribution.
         | 
| 96 | 
            +
                    In this context, it quantifies the difference between the mean profiles of a target group and a reference group.
         | 
| 83 97 |  | 
| 84 98 | 
             
                Attributes:
         | 
| 85 99 | 
             
                    metric: Name of distance metric.
         | 
| @@ -99,6 +113,7 @@ class Distance: | |
| 99 113 | 
             
                def __init__(
         | 
| 100 114 | 
             
                    self,
         | 
| 101 115 | 
             
                    metric: str = "edistance",
         | 
| 116 | 
            +
                    agg_fct: Callable = np.mean,
         | 
| 102 117 | 
             
                    layer_key: str = None,
         | 
| 103 118 | 
             
                    obsm_key: str = None,
         | 
| 104 119 | 
             
                    cell_wise_metric: str = "euclidean",
         | 
| @@ -106,37 +121,38 @@ class Distance: | |
| 106 121 | 
             
                    """Initialize Distance class.
         | 
| 107 122 |  | 
| 108 123 | 
             
                    Args:
         | 
| 109 | 
            -
                        metric: Distance metric to use. | 
| 124 | 
            +
                        metric: Distance metric to use.
         | 
| 125 | 
            +
                        agg_fct: Aggregation function to generate pseudobulk vectors.
         | 
| 110 126 | 
             
                        layer_key: Name of the counts layer containing raw counts to calculate distances for.
         | 
| 111 127 | 
             
                                          Mutually exclusive with 'obsm_key'.
         | 
| 112 | 
            -
                                           | 
| 128 | 
            +
                                          Is not used if `None`.
         | 
| 113 129 | 
             
                        obsm_key: Name of embedding in adata.obsm to use.
         | 
| 114 | 
            -
                                  Mutually exclusive with ' | 
| 115 | 
            -
                                  Defaults to None, but is set to "X_pca" if not set  | 
| 130 | 
            +
                                  Mutually exclusive with 'layer_key'.
         | 
| 131 | 
            +
                                  Defaults to None, but is set to "X_pca" internally if neither 'layer_key' nor 'obsm_key' is provided.
         | 
| 116 132 | 
             
                        cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
         | 
| 117 | 
            -
                                            Defaults to "euclidean".
         | 
| 118 133 | 
             
                    """
         | 
| 119 134 | 
             
                    metric_fct: AbstractDistance = None
         | 
| 135 | 
            +
                    self.aggregation_func = agg_fct
         | 
| 120 136 | 
             
                    if metric == "edistance":
         | 
| 121 137 | 
             
                        metric_fct = Edistance()
         | 
| 122 138 | 
             
                    elif metric == "euclidean":
         | 
| 123 | 
            -
                        metric_fct = EuclideanDistance()
         | 
| 139 | 
            +
                        metric_fct = EuclideanDistance(self.aggregation_func)
         | 
| 124 140 | 
             
                    elif metric == "root_mean_squared_error":
         | 
| 125 | 
            -
                        metric_fct = EuclideanDistance()
         | 
| 141 | 
            +
                        metric_fct = EuclideanDistance(self.aggregation_func)
         | 
| 126 142 | 
             
                    elif metric == "mse":
         | 
| 127 | 
            -
                        metric_fct = MeanSquaredDistance()
         | 
| 143 | 
            +
                        metric_fct = MeanSquaredDistance(self.aggregation_func)
         | 
| 128 144 | 
             
                    elif metric == "mean_absolute_error":
         | 
| 129 | 
            -
                        metric_fct = MeanAbsoluteDistance()
         | 
| 145 | 
            +
                        metric_fct = MeanAbsoluteDistance(self.aggregation_func)
         | 
| 130 146 | 
             
                    elif metric == "pearson_distance":
         | 
| 131 | 
            -
                        metric_fct = PearsonDistance()
         | 
| 147 | 
            +
                        metric_fct = PearsonDistance(self.aggregation_func)
         | 
| 132 148 | 
             
                    elif metric == "spearman_distance":
         | 
| 133 | 
            -
                        metric_fct = SpearmanDistance()
         | 
| 149 | 
            +
                        metric_fct = SpearmanDistance(self.aggregation_func)
         | 
| 134 150 | 
             
                    elif metric == "kendalltau_distance":
         | 
| 135 | 
            -
                        metric_fct = KendallTauDistance()
         | 
| 151 | 
            +
                        metric_fct = KendallTauDistance(self.aggregation_func)
         | 
| 136 152 | 
             
                    elif metric == "cosine_distance":
         | 
| 137 | 
            -
                        metric_fct = CosineDistance()
         | 
| 153 | 
            +
                        metric_fct = CosineDistance(self.aggregation_func)
         | 
| 138 154 | 
             
                    elif metric == "r2_distance":
         | 
| 139 | 
            -
                        metric_fct = R2ScoreDistance()
         | 
| 155 | 
            +
                        metric_fct = R2ScoreDistance(self.aggregation_func)
         | 
| 140 156 | 
             
                    elif metric == "mean_pairwise":
         | 
| 141 157 | 
             
                        metric_fct = MeanPairwiseDistance()
         | 
| 142 158 | 
             
                    elif metric == "mmd":
         | 
| @@ -155,14 +171,17 @@ class Distance: | |
| 155 171 | 
             
                        metric_fct = ClassifierProbaDistance()
         | 
| 156 172 | 
             
                    elif metric == "classifier_cp":
         | 
| 157 173 | 
             
                        metric_fct = ClassifierClassProjection()
         | 
| 174 | 
            +
                    elif metric == "mean_var_distribution":
         | 
| 175 | 
            +
                        metric_fct = MeanVarDistributionDistance()
         | 
| 176 | 
            +
                    elif metric == "mahalanobis":
         | 
| 177 | 
            +
                        metric_fct = MahalanobisDistance(self.aggregation_func)
         | 
| 158 178 | 
             
                    else:
         | 
| 159 179 | 
             
                        raise ValueError(f"Metric {metric} not recognized.")
         | 
| 160 180 | 
             
                    self.metric_fct = metric_fct
         | 
| 161 181 |  | 
| 162 182 | 
             
                    if layer_key and obsm_key:
         | 
| 163 183 | 
             
                        raise ValueError(
         | 
| 164 | 
            -
                            "Cannot use ' | 
| 165 | 
            -
                            "Please provide only one of the two keys."
         | 
| 184 | 
            +
                            "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
         | 
| 166 185 | 
             
                        )
         | 
| 167 186 | 
             
                    if not layer_key and not obsm_key:
         | 
| 168 187 | 
             
                        obsm_key = "X_pca"
         | 
| @@ -195,37 +214,80 @@ class Distance: | |
| 195 214 | 
             
                        >>> D = Distance(X, Y)
         | 
| 196 215 | 
             
                    """
         | 
| 197 216 | 
             
                    if issparse(X):
         | 
| 198 | 
            -
                        X = X. | 
| 217 | 
            +
                        X = X.toarray()
         | 
| 199 218 | 
             
                    if issparse(Y):
         | 
| 200 | 
            -
                        Y = Y. | 
| 219 | 
            +
                        Y = Y.toarray()
         | 
| 201 220 |  | 
| 202 221 | 
             
                    if len(X) == 0 or len(Y) == 0:
         | 
| 203 222 | 
             
                        raise ValueError("Neither X nor Y can be empty.")
         | 
| 204 223 |  | 
| 205 224 | 
             
                    return self.metric_fct(X, Y, **kwargs)
         | 
| 206 225 |  | 
| 226 | 
            +
                def bootstrap(
         | 
| 227 | 
            +
                    self,
         | 
| 228 | 
            +
                    X: np.ndarray,
         | 
| 229 | 
            +
                    Y: np.ndarray,
         | 
| 230 | 
            +
                    *,
         | 
| 231 | 
            +
                    n_bootstrap: int = 100,
         | 
| 232 | 
            +
                    random_state: int = 0,
         | 
| 233 | 
            +
                    **kwargs,
         | 
| 234 | 
            +
                ) -> MeanVar:
         | 
| 235 | 
            +
                    """Bootstrap computation of mean and variance of the distance between vectors X and Y.
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    Args:
         | 
| 238 | 
            +
                        X: First vector of shape (n_samples, n_features).
         | 
| 239 | 
            +
                        Y: Second vector of shape (n_samples, n_features).
         | 
| 240 | 
            +
                        n_bootstrap: Number of bootstrap samples.
         | 
| 241 | 
            +
                        random_state: Random state for bootstrapping.
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    Returns:
         | 
| 244 | 
            +
                        MeanVar: Mean and variance of distance between X and Y.
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    Examples:
         | 
| 247 | 
            +
                        >>> import pertpy as pt
         | 
| 248 | 
            +
                        >>> adata = pt.dt.distance_example()
         | 
| 249 | 
            +
                        >>> Distance = pt.tools.Distance(metric="edistance")
         | 
| 250 | 
            +
                        >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
         | 
| 251 | 
            +
                        >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
         | 
| 252 | 
            +
                        >>> D = Distance.bootstrap(X, Y)
         | 
| 253 | 
            +
                    """
         | 
| 254 | 
            +
                    return self._bootstrap_mode(
         | 
| 255 | 
            +
                        X,
         | 
| 256 | 
            +
                        Y,
         | 
| 257 | 
            +
                        n_bootstraps=n_bootstrap,
         | 
| 258 | 
            +
                        random_state=random_state,
         | 
| 259 | 
            +
                        **kwargs,
         | 
| 260 | 
            +
                    )
         | 
| 261 | 
            +
             | 
| 207 262 | 
             
                def pairwise(
         | 
| 208 263 | 
             
                    self,
         | 
| 209 264 | 
             
                    adata: AnnData,
         | 
| 210 265 | 
             
                    groupby: str,
         | 
| 211 266 | 
             
                    groups: list[str] | None = None,
         | 
| 267 | 
            +
                    bootstrap: bool = False,
         | 
| 268 | 
            +
                    n_bootstrap: int = 100,
         | 
| 269 | 
            +
                    random_state: int = 0,
         | 
| 212 270 | 
             
                    show_progressbar: bool = True,
         | 
| 213 271 | 
             
                    n_jobs: int = -1,
         | 
| 214 272 | 
             
                    **kwargs,
         | 
| 215 | 
            -
                ) -> pd.DataFrame:
         | 
| 273 | 
            +
                ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
         | 
| 216 274 | 
             
                    """Get pairwise distances between groups of cells.
         | 
| 217 275 |  | 
| 218 276 | 
             
                    Args:
         | 
| 219 277 | 
             
                        adata: Annotated data matrix.
         | 
| 220 278 | 
             
                        groupby: Column name in adata.obs.
         | 
| 221 279 | 
             
                        groups: List of groups to compute pairwise distances for.
         | 
| 222 | 
            -
                                If None, uses all groups. | 
| 223 | 
            -
                         | 
| 280 | 
            +
                                If None, uses all groups.
         | 
| 281 | 
            +
                        bootstrap: Whether to bootstrap the distance.
         | 
| 282 | 
            +
                        n_bootstrap: Number of bootstrap samples.
         | 
| 283 | 
            +
                        random_state: Random state for bootstrapping.
         | 
| 284 | 
            +
                        show_progressbar: Whether to show progress bar.
         | 
| 224 285 | 
             
                        n_jobs: Number of cores to use. Defaults to -1 (all).
         | 
| 225 286 | 
             
                        kwargs: Additional keyword arguments passed to the metric function.
         | 
| 226 287 |  | 
| 227 288 | 
             
                    Returns:
         | 
| 228 289 | 
             
                        pd.DataFrame: Dataframe with pairwise distances.
         | 
| 290 | 
            +
                        tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
         | 
| 229 291 |  | 
| 230 292 | 
             
                    Examples:
         | 
| 231 293 | 
             
                        >>> import pertpy as pt
         | 
| @@ -236,6 +298,8 @@ class Distance: | |
| 236 298 | 
             
                    groups = adata.obs[groupby].unique() if groups is None else groups
         | 
| 237 299 | 
             
                    grouping = adata.obs[groupby].copy()
         | 
| 238 300 | 
             
                    df = pd.DataFrame(index=groups, columns=groups, dtype=float)
         | 
| 301 | 
            +
                    if bootstrap:
         | 
| 302 | 
            +
                        df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
         | 
| 239 303 | 
             
                    fct = track if show_progressbar else lambda iterable: iterable
         | 
| 240 304 |  | 
| 241 305 | 
             
                    # Some metrics are able to handle precomputed distances. This means that
         | 
| @@ -251,16 +315,29 @@ class Distance: | |
| 251 315 | 
             
                        for index_x, group_x in enumerate(fct(groups)):
         | 
| 252 316 | 
             
                            idx_x = grouping == group_x
         | 
| 253 317 | 
             
                            for group_y in groups[index_x:]:  # type: ignore
         | 
| 254 | 
            -
                                 | 
| 255 | 
            -
             | 
| 318 | 
            +
                                # subset the pairwise distance matrix to the two groups
         | 
| 319 | 
            +
                                idx_y = grouping == group_y
         | 
| 320 | 
            +
                                sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
         | 
| 321 | 
            +
                                sub_idx = grouping[idx_x | idx_y] == group_x
         | 
| 322 | 
            +
                                if not bootstrap:
         | 
| 323 | 
            +
                                    if group_x == group_y:
         | 
| 324 | 
            +
                                        dist = 0.0
         | 
| 325 | 
            +
                                    else:
         | 
| 326 | 
            +
                                        dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
         | 
| 327 | 
            +
                                    df.loc[group_x, group_y] = dist
         | 
| 328 | 
            +
                                    df.loc[group_y, group_x] = dist
         | 
| 329 | 
            +
             | 
| 256 330 | 
             
                                else:
         | 
| 257 | 
            -
                                     | 
| 258 | 
            -
             | 
| 259 | 
            -
             | 
| 260 | 
            -
             | 
| 261 | 
            -
             | 
| 262 | 
            -
             | 
| 263 | 
            -
             | 
| 331 | 
            +
                                    bootstrap_output = self._bootstrap_mode_precomputed(
         | 
| 332 | 
            +
                                        sub_pwd,
         | 
| 333 | 
            +
                                        sub_idx,
         | 
| 334 | 
            +
                                        n_bootstraps=n_bootstrap,
         | 
| 335 | 
            +
                                        random_state=random_state,
         | 
| 336 | 
            +
                                        **kwargs,
         | 
| 337 | 
            +
                                    )
         | 
| 338 | 
            +
                                    # In the bootstrap case, distance of group to itself is a mean and can be non-zero
         | 
| 339 | 
            +
                                    df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
         | 
| 340 | 
            +
                                    df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
         | 
| 264 341 | 
             
                    else:
         | 
| 265 342 | 
             
                        if self.layer_key:
         | 
| 266 343 | 
             
                            embedding = adata.layers[self.layer_key]
         | 
| @@ -269,18 +346,39 @@ class Distance: | |
| 269 346 | 
             
                        for index_x, group_x in enumerate(fct(groups)):
         | 
| 270 347 | 
             
                            cells_x = embedding[grouping == group_x].copy()
         | 
| 271 348 | 
             
                            for group_y in groups[index_x:]:  # type: ignore
         | 
| 272 | 
            -
                                 | 
| 273 | 
            -
             | 
| 349 | 
            +
                                cells_y = embedding[grouping == group_y].copy()
         | 
| 350 | 
            +
                                if not bootstrap:
         | 
| 351 | 
            +
                                    # By distance axiom, the distance between a group and itself is 0
         | 
| 352 | 
            +
                                    dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                                    df.loc[group_x, group_y] = dist
         | 
| 355 | 
            +
                                    df.loc[group_y, group_x] = dist
         | 
| 274 356 | 
             
                                else:
         | 
| 275 | 
            -
                                     | 
| 276 | 
            -
             | 
| 277 | 
            -
             | 
| 278 | 
            -
             | 
| 357 | 
            +
                                    bootstrap_output = self.bootstrap(
         | 
| 358 | 
            +
                                        cells_x,
         | 
| 359 | 
            +
                                        cells_y,
         | 
| 360 | 
            +
                                        n_bootstrap=n_bootstrap,
         | 
| 361 | 
            +
                                        random_state=random_state,
         | 
| 362 | 
            +
                                        **kwargs,
         | 
| 363 | 
            +
                                    )
         | 
| 364 | 
            +
                                    # In the bootstrap case, distance of group to itself is a mean and can be non-zero
         | 
| 365 | 
            +
                                    df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
         | 
| 366 | 
            +
                                    df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
         | 
| 367 | 
            +
             | 
| 279 368 | 
             
                    df.index.name = groupby
         | 
| 280 369 | 
             
                    df.columns.name = groupby
         | 
| 281 370 | 
             
                    df.name = f"pairwise {self.metric}"
         | 
| 282 371 |  | 
| 283 | 
            -
                     | 
| 372 | 
            +
                    if not bootstrap:
         | 
| 373 | 
            +
                        return df
         | 
| 374 | 
            +
                    else:
         | 
| 375 | 
            +
                        df = df.fillna(0)
         | 
| 376 | 
            +
                        df_var.index.name = groupby
         | 
| 377 | 
            +
                        df_var.columns.name = groupby
         | 
| 378 | 
            +
                        df_var = df_var.fillna(0)
         | 
| 379 | 
            +
                        df_var.name = f"pairwise {self.metric} variance"
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                        return df, df_var
         | 
| 284 382 |  | 
| 285 383 | 
             
                def onesided_distances(
         | 
| 286 384 | 
             
                    self,
         | 
| @@ -288,10 +386,13 @@ class Distance: | |
| 288 386 | 
             
                    groupby: str,
         | 
| 289 387 | 
             
                    selected_group: str | None = None,
         | 
| 290 388 | 
             
                    groups: list[str] | None = None,
         | 
| 389 | 
            +
                    bootstrap: bool = False,
         | 
| 390 | 
            +
                    n_bootstrap: int = 100,
         | 
| 391 | 
            +
                    random_state: int = 0,
         | 
| 291 392 | 
             
                    show_progressbar: bool = True,
         | 
| 292 393 | 
             
                    n_jobs: int = -1,
         | 
| 293 394 | 
             
                    **kwargs,
         | 
| 294 | 
            -
                ) -> pd.DataFrame:
         | 
| 395 | 
            +
                ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
         | 
| 295 396 | 
             
                    """Get distances between one selected cell group and the remaining other cell groups.
         | 
| 296 397 |  | 
| 297 398 | 
             
                    Args:
         | 
| @@ -299,13 +400,18 @@ class Distance: | |
| 299 400 | 
             
                        groupby: Column name in adata.obs.
         | 
| 300 401 | 
             
                        selected_group: Group to compute pairwise distances to all other.
         | 
| 301 402 | 
             
                        groups: List of groups to compute distances to selected_group for.
         | 
| 302 | 
            -
                                If None, uses all groups. | 
| 303 | 
            -
                         | 
| 403 | 
            +
                                If None, uses all groups.
         | 
| 404 | 
            +
                        bootstrap: Whether to bootstrap the distance.
         | 
| 405 | 
            +
                        n_bootstrap: Number of bootstrap samples.
         | 
| 406 | 
            +
                        random_state: Random state for bootstrapping.
         | 
| 407 | 
            +
                        show_progressbar: Whether to show progress bar.
         | 
| 304 408 | 
             
                        n_jobs: Number of cores to use. Defaults to -1 (all).
         | 
| 305 409 | 
             
                        kwargs: Additional keyword arguments passed to the metric function.
         | 
| 306 410 |  | 
| 307 411 | 
             
                    Returns:
         | 
| 308 412 | 
             
                        pd.DataFrame: Dataframe with distances of groups to selected_group.
         | 
| 413 | 
            +
                        tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
         | 
| 414 | 
            +
             | 
| 309 415 |  | 
| 310 416 | 
             
                    Examples:
         | 
| 311 417 | 
             
                        >>> import pertpy as pt
         | 
| @@ -314,20 +420,30 @@ class Distance: | |
| 314 420 | 
             
                        >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
         | 
| 315 421 | 
             
                    """
         | 
| 316 422 | 
             
                    if self.metric == "classifier_cp":
         | 
| 423 | 
            +
                        if bootstrap:
         | 
| 424 | 
            +
                            raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.")
         | 
| 317 425 | 
             
                        return self.metric_fct.onesided_distances(  # type: ignore
         | 
| 318 | 
            -
                            adata, | 
| 426 | 
            +
                            adata,
         | 
| 427 | 
            +
                            groupby,
         | 
| 428 | 
            +
                            selected_group,
         | 
| 429 | 
            +
                            groups,
         | 
| 430 | 
            +
                            show_progressbar,
         | 
| 431 | 
            +
                            n_jobs,
         | 
| 432 | 
            +
                            **kwargs,
         | 
| 319 433 | 
             
                        )
         | 
| 320 434 |  | 
| 321 435 | 
             
                    groups = adata.obs[groupby].unique() if groups is None else groups
         | 
| 322 436 | 
             
                    grouping = adata.obs[groupby].copy()
         | 
| 323 437 | 
             
                    df = pd.Series(index=groups, dtype=float)
         | 
| 438 | 
            +
                    if bootstrap:
         | 
| 439 | 
            +
                        df_var = pd.Series(index=groups, dtype=float)
         | 
| 324 440 | 
             
                    fct = track if show_progressbar else lambda iterable: iterable
         | 
| 325 441 |  | 
| 326 442 | 
             
                    # Some metrics are able to handle precomputed distances. This means that
         | 
| 327 443 | 
             
                    # the pairwise distances between all cells are computed once and then
         | 
| 328 444 | 
             
                    # passed to the metric function. This is much faster than computing the
         | 
| 329 445 | 
             
                    # pairwise distances for each group separately. Other metrics are not
         | 
| 330 | 
            -
                    # able to handle precomputed distances such as the  | 
| 446 | 
            +
                    # able to handle precomputed distances such as the PseudobulkDistance.
         | 
| 331 447 | 
             
                    if self.metric_fct.accepts_precomputed:
         | 
| 332 448 | 
             
                        # Precompute the pairwise distances if needed
         | 
| 333 449 | 
             
                        if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
         | 
| @@ -337,14 +453,25 @@ class Distance: | |
| 337 453 | 
             
                            idx_x = grouping == group_x
         | 
| 338 454 | 
             
                            group_y = selected_group
         | 
| 339 455 | 
             
                            if group_x == group_y:
         | 
| 340 | 
            -
                                 | 
| 456 | 
            +
                                df.loc[group_x] = 0.0  # by distance axiom
         | 
| 341 457 | 
             
                            else:
         | 
| 342 458 | 
             
                                idx_y = grouping == group_y
         | 
| 343 459 | 
             
                                # subset the pairwise distance matrix to the two groups
         | 
| 344 460 | 
             
                                sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
         | 
| 345 461 | 
             
                                sub_idx = grouping[idx_x | idx_y] == group_x
         | 
| 346 | 
            -
                                 | 
| 347 | 
            -
             | 
| 462 | 
            +
                                if not bootstrap:
         | 
| 463 | 
            +
                                    dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
         | 
| 464 | 
            +
                                    df.loc[group_x] = dist
         | 
| 465 | 
            +
                                else:
         | 
| 466 | 
            +
                                    bootstrap_output = self._bootstrap_mode_precomputed(
         | 
| 467 | 
            +
                                        sub_pwd,
         | 
| 468 | 
            +
                                        sub_idx,
         | 
| 469 | 
            +
                                        n_bootstraps=n_bootstrap,
         | 
| 470 | 
            +
                                        random_state=random_state,
         | 
| 471 | 
            +
                                        **kwargs,
         | 
| 472 | 
            +
                                    )
         | 
| 473 | 
            +
                                    df.loc[group_x] = bootstrap_output.mean
         | 
| 474 | 
            +
                                    df_var.loc[group_x] = bootstrap_output.variance
         | 
| 348 475 | 
             
                    else:
         | 
| 349 476 | 
             
                        if self.layer_key:
         | 
| 350 477 | 
             
                            embedding = adata.layers[self.layer_key]
         | 
| @@ -353,15 +480,32 @@ class Distance: | |
| 353 480 | 
             
                        for group_x in fct(groups):
         | 
| 354 481 | 
             
                            cells_x = embedding[grouping == group_x].copy()
         | 
| 355 482 | 
             
                            group_y = selected_group
         | 
| 356 | 
            -
                             | 
| 357 | 
            -
             | 
| 483 | 
            +
                            cells_y = embedding[grouping == group_y].copy()
         | 
| 484 | 
            +
                            if not bootstrap:
         | 
| 485 | 
            +
                                # By distance axiom, the distance between a group and itself is 0
         | 
| 486 | 
            +
                                dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
         | 
| 487 | 
            +
                                df.loc[group_x] = dist
         | 
| 358 488 | 
             
                            else:
         | 
| 359 | 
            -
                                 | 
| 360 | 
            -
             | 
| 361 | 
            -
             | 
| 489 | 
            +
                                bootstrap_output = self.bootstrap(
         | 
| 490 | 
            +
                                    cells_x,
         | 
| 491 | 
            +
                                    cells_y,
         | 
| 492 | 
            +
                                    n_bootstrap=n_bootstrap,
         | 
| 493 | 
            +
                                    random_state=random_state,
         | 
| 494 | 
            +
                                    **kwargs,
         | 
| 495 | 
            +
                                )
         | 
| 496 | 
            +
                                # In the bootstrap case, distance of group to itself is a mean and can be non-zero
         | 
| 497 | 
            +
                                df.loc[group_x] = bootstrap_output.mean
         | 
| 498 | 
            +
                                df_var.loc[group_x] = bootstrap_output.variance
         | 
| 362 499 | 
             
                    df.index.name = groupby
         | 
| 363 500 | 
             
                    df.name = f"{self.metric} to {selected_group}"
         | 
| 364 | 
            -
                     | 
| 501 | 
            +
                    if not bootstrap:
         | 
| 502 | 
            +
                        return df
         | 
| 503 | 
            +
                    else:
         | 
| 504 | 
            +
                        df_var.index.name = groupby
         | 
| 505 | 
            +
                        df_var = df_var.fillna(0)
         | 
| 506 | 
            +
                        df_var.name = f"pairwise {self.metric} variance to {selected_group}"
         | 
| 507 | 
            +
             | 
| 508 | 
            +
                        return df, df_var
         | 
| 365 509 |  | 
| 366 510 | 
             
                def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
         | 
| 367 511 | 
             
                    """Precompute pairwise distances between all cells, writes to adata.obsp.
         | 
| @@ -387,6 +531,77 @@ class Distance: | |
| 387 531 | 
             
                    pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
         | 
| 388 532 | 
             
                    adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
         | 
| 389 533 |  | 
| 534 | 
            +
                def compare_distance(
                    self,
                    pert: np.ndarray,
                    pred: np.ndarray,
                    ctrl: np.ndarray,
                    mode: Literal["simple", "scaled"] = "simple",
                    fit_to_pert_and_ctrl: bool = False,
                    **kwargs,
                ) -> float:
                    """Compute the score of simulating a perturbation.

                    The score is the distance between the real and the simulated perturbed data
                    divided by the distance between the control and the simulated perturbed data,
                    both computed with ``self.metric_fct``.

                    Args:
                        pert: Real perturbed data.
                        pred: Simulated perturbed data.
                        ctrl: Control data.
                        mode: Mode to use. "simple" uses the data as passed; "scaled" min-max scales
                              `pert`, `pred` and `ctrl` with one shared scaler before computing distances.
                        fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
                        kwargs: Additional keyword arguments passed to the metric function.

                    Returns:
                        The ratio ``d(pert, pred) / d(ctrl, pred)``.

                    Raises:
                        ValueError: If `mode` is neither "simple" nor "scaled".
                    """
                    if mode == "simple":
                        pass  # nothing to be done
                    elif mode == "scaled":
                        from sklearn.preprocessing import MinMaxScaler

                        scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl)
                        pred = scaler.transform(pred)
                        pert = scaler.transform(pert)
                        # ctrl must live in the same scaled space as pred, otherwise the
                        # denominator compares raw control values against scaled predictions.
                        ctrl = scaler.transform(ctrl)
                    else:
                        raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")

                    d1 = self.metric_fct(pert, pred, **kwargs)
                    d2 = self.metric_fct(ctrl, pred, **kwargs)
                    return d1 / d2
         | 
| 567 | 
            +
             | 
| 568 | 
            +
                def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
                    """Estimate mean and variance of the distance between `X` and `Y` by bootstrapping.

                    In every round the rows of `X` and of `Y` are resampled with replacement,
                    each group keeping its own number of rows, and the distance between the
                    two resampled groups is computed by calling this object.

                    Args:
                        X: First group; rows (axis 0) are resampled.
                        Y: Second group; rows (axis 0) are resampled.
                        n_bootstraps: Number of bootstrap rounds.
                        random_state: Seed passed to ``np.random.default_rng``.
                        kwargs: Additional keyword arguments forwarded to the distance call.

                    Returns:
                        MeanVar with the mean and the variance of the bootstrapped distances.
                    """
                    rng = np.random.default_rng(random_state)
                    n_x, n_y = X.shape[0], Y.shape[0]

                    distances = []
                    for _ in range(n_bootstraps):
                        X_bootstrapped = X[rng.choice(a=n_x, size=n_x, replace=True)]
                        # Resample Y at its own size (not X's) so that the balance between the
                        # two groups is maintained, as _bootstrap_mode_precomputed does.
                        Y_bootstrapped = Y[rng.choice(a=n_y, size=n_y, replace=True)]

                        distances.append(self(X_bootstrapped, Y_bootstrapped, **kwargs))

                    return MeanVar(mean=np.mean(distances), variance=np.var(distances))
         | 
| 582 | 
            +
             | 
| 583 | 
            +
                def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar:
                    """Bootstrap a distance from a precomputed pairwise distance matrix.

                    Args:
                        sub_pwd: Pairwise distance matrix restricted to the cells of the two groups.
                        sub_idx: Boolean mask over those cells, indexed by cell label, that is
                                 True for cells of the first group.
                        n_bootstraps: Number of bootstrap rounds.
                        random_state: Seed passed to ``np.random.default_rng``.
                        kwargs: Additional keyword arguments forwarded to ``metric_fct.from_precomputed``.

                    Returns:
                        MeanVar with the mean and the variance of the bootstrapped distances.
                    """
                    rng = np.random.default_rng(random_state)

                    # The group memberships do not change between rounds, so look them up once.
                    members = sub_idx[sub_idx]
                    non_members = sub_idx[~sub_idx]

                    def _one_round():
                        # To maintain the number of cells for both groups (whatever balancing they may have),
                        # the positive and negative labels are drawn separately.
                        drawn_pos = rng.choice(a=members.index, size=members.size, replace=True)
                        drawn_neg = rng.choice(a=non_members.index, size=non_members.size, replace=True)
                        drawn_labels = np.concatenate([drawn_pos, drawn_neg])
                        # Translate the drawn labels into row/column positions of sub_pwd
                        drawn_positions = sub_idx.index.get_indexer(drawn_labels)

                        resampled_idx = sub_idx[drawn_labels]
                        resampled_pwd = sub_pwd[drawn_positions, :][:, drawn_positions]
                        return self.metric_fct.from_precomputed(resampled_pwd, resampled_idx, **kwargs)

                    distances = [_one_round() for _ in range(n_bootstraps)]

                    return MeanVar(mean=np.mean(distances), variance=np.var(distances))
         | 
| 604 | 
            +
             | 
| 390 605 |  | 
| 391 606 | 
             
            class AbstractDistance(ABC):
         | 
| 392 607 | 
             
                """Abstract class of distance metrics between two sets of vectors."""
         | 
| @@ -500,12 +715,17 @@ class WassersteinDistance(AbstractDistance): | |
| 500 715 | 
             
            class EuclideanDistance(AbstractDistance):
                """Euclidean distance between pseudobulk vectors."""

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return the 2-norm of the difference between the aggregated `X` and `Y`.

                    `kwargs` are forwarded to ``np.linalg.norm``.
                    """
                    aggregate = self.aggregation_func
                    difference = aggregate(X, axis=0) - aggregate(Y, axis=0)
                    return np.linalg.norm(difference, ord=2, **kwargs)

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "EuclideanDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -514,12 +734,21 @@ class EuclideanDistance(AbstractDistance): | |
| 514 734 | 
             
            class MeanSquaredDistance(AbstractDistance):
                """Mean squared distance between pseudobulk vectors."""

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return the squared 2-norm of the aggregated difference divided by ``X.shape[1]``.

                    `kwargs` are forwarded to ``np.linalg.norm``.
                    """
                    aggregate = self.aggregation_func
                    difference = aggregate(X, axis=0) - aggregate(Y, axis=0)
                    squared_norm = np.linalg.norm(difference, ord=2, **kwargs) ** 2
                    return squared_norm / X.shape[1]

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "MeanSquaredDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -528,12 +757,20 @@ class MeanSquaredDistance(AbstractDistance): | |
| 528 757 | 
             
            class MeanAbsoluteDistance(AbstractDistance):
                """Absolute (Norm-1) distance between pseudobulk vectors."""

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return the 1-norm of the aggregated difference divided by ``X.shape[1]``.

                    `kwargs` are forwarded to ``np.linalg.norm``.
                    """
                    aggregate = self.aggregation_func
                    difference = aggregate(X, axis=0) - aggregate(Y, axis=0)
                    absolute_norm = np.linalg.norm(difference, ord=1, **kwargs)
                    return absolute_norm / X.shape[1]

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "MeanAbsoluteDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -558,12 +795,13 @@ class MeanPairwiseDistance(AbstractDistance): | |
| 558 795 | 
             
            class PearsonDistance(AbstractDistance):
                """Pearson distance between pseudobulk vectors."""

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return one minus the Pearson correlation of the aggregated `X` and `Y`.

                    `kwargs` are accepted but not used.
                    """
                    aggregate = self.aggregation_func
                    bulk_x = aggregate(X, axis=0)
                    bulk_y = aggregate(Y, axis=0)
                    correlation = pearsonr(bulk_x, bulk_y)[0]
                    return 1 - correlation

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "PearsonDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -572,12 +810,13 @@ class PearsonDistance(AbstractDistance): | |
| 572 810 | 
             
            class SpearmanDistance(AbstractDistance):
                """Spearman distance between pseudobulk vectors."""

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return one minus the Spearman correlation of the aggregated `X` and `Y`.

                    `kwargs` are accepted but not used.
                    """
                    aggregate = self.aggregation_func
                    bulk_x = aggregate(X, axis=0)
                    bulk_y = aggregate(Y, axis=0)
                    correlation = spearmanr(bulk_x, bulk_y)[0]
                    return 1 - correlation

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "SpearmanDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -586,12 +825,13 @@ class SpearmanDistance(AbstractDistance): | |
| 586 825 | 
             
            class KendallTauDistance(AbstractDistance):
         | 
| 587 826 | 
             
                """Kendall-tau distance between pseudobulk vectors."""
         | 
| 588 827 |  | 
| 589 | 
            -
                def __init__(self) -> None:
         | 
| 828 | 
            +
                def __init__(self, aggregation_func: Callable = np.mean) -> None:
         | 
| 590 829 | 
             
                    super().__init__()
         | 
| 591 830 | 
             
                    self.accepts_precomputed = False
         | 
| 831 | 
            +
                    self.aggregation_func = aggregation_func
         | 
| 592 832 |  | 
| 593 833 | 
             
                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
         | 
| 594 | 
            -
                    x, y =  | 
| 834 | 
            +
                    x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)
         | 
| 595 835 | 
             
                    n = len(x)
         | 
| 596 836 | 
             
                    tau_corr = kendalltau(x, y).statistic
         | 
| 597 837 | 
             
                    tau_dist = (1 - tau_corr) * n * (n - 1) / 4
         | 
| @@ -604,12 +844,13 @@ class KendallTauDistance(AbstractDistance): | |
| 604 844 | 
             
            class CosineDistance(AbstractDistance):
                """Cosine distance between pseudobulk vectors."""

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return the cosine distance between the aggregated `X` and `Y`.

                    `kwargs` are accepted but not used.
                    """
                    aggregate = self.aggregation_func
                    bulk_x = aggregate(X, axis=0)
                    bulk_y = aggregate(Y, axis=0)
                    return cosine(bulk_x, bulk_y)

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "CosineDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -620,12 +861,13 @@ class R2ScoreDistance(AbstractDistance): | |
| 620 861 |  | 
| 621 862 | 
             
                # NOTE: This is not a distance metric but a similarity metric.
         | 
| 622 863 |  | 
| 623 | 
            -
                def __init__(self) -> None:
         | 
| 864 | 
            +
                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Store the function used to aggregate each group along axis 0."""
                    super().__init__()
                    # Set after the base initializer so these values are the ones that stick.
                    self.accepts_precomputed, self.aggregation_func = False, aggregation_func
         | 
| 626 868 |  | 
| 627 869 | 
             
                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Return one minus the R2 score of the aggregated `X` against the aggregated `Y`.

                    The aggregated `X` is passed as the first and the aggregated `Y` as the
                    second argument of ``r2_score``. `kwargs` are accepted but not used.
                    """
                    aggregate = self.aggregation_func
                    bulk_x = aggregate(X, axis=0)
                    bulk_y = aggregate(Y, axis=0)
                    score = r2_score(bulk_x, bulk_y)
                    return 1 - score
         | 
| 629 871 |  | 
| 630 872 | 
             
                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this distance; always raises NotImplementedError."""
                    message = "R2ScoreDistance cannot be called on a pairwise distance matrix."
                    raise NotImplementedError(message)
         | 
| @@ -834,6 +1076,7 @@ class ClassifierClassProjection(AbstractDistance): | |
| 834 1076 | 
             
                    Similar to the parent function, the returned dataframe contains only the specified groups.
         | 
| 835 1077 | 
             
                    """
         | 
| 836 1078 | 
             
                    groups = adata.obs[groupby].unique() if groups is None else groups
         | 
| 1079 | 
            +
                    fct = track if show_progressbar else lambda iterable: iterable
         | 
| 837 1080 |  | 
| 838 1081 | 
             
                    X = adata[adata.obs[groupby] != selected_group].X
         | 
| 839 1082 | 
             
                    labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values
         | 
| @@ -844,7 +1087,8 @@ class ClassifierClassProjection(AbstractDistance): | |
| 844 1087 | 
             
                    test_probas = reg.predict_proba(Y)
         | 
| 845 1088 |  | 
| 846 1089 | 
             
                    df = pd.Series(index=groups, dtype=float)
         | 
| 847 | 
            -
             | 
| 1090 | 
            +
             | 
| 1091 | 
            +
                    for group in fct(groups):
         | 
| 848 1092 | 
             
                        if group == selected_group:
         | 
| 849 1093 | 
             
                            df.loc[group] = 0
         | 
| 850 1094 | 
             
                        else:
         | 
| @@ -857,3 +1101,95 @@ class ClassifierClassProjection(AbstractDistance): | |
| 857 1101 |  | 
| 858 1102 | 
             
                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         | 
| 859 1103 | 
             
                    raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.")
         | 
| 1104 | 
            +
             | 
| 1105 | 
            +
             | 
| 1106 | 
            +
            class MeanVarDistributionDistance(AbstractDistance):
                """Distance between mean-var distributions of gene expression.

                For each input matrix the per-gene mean and variance are computed and
                log-transformed, a 2D kernel density estimate is fitted on the resulting
                (log-mean, log-variance) points, and both estimates are compared on a shared grid.
                """

                def __init__(self) -> None:
                    super().__init__()
                    # The metric needs the raw matrices, not pairwise cell distances.
                    self.accepts_precomputed = False

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Difference of mean-var distributions in 2 matrices.

                    Args:
                        X: Normalized and log transformed cells x genes count matrix.
                        Y: Normalized and log transformed cells x genes count matrix.
                        **kwargs: Ignored.

                    Returns:
                        Mean squared difference between the two KDE scores
                        (as returned by ``KernelDensity.score_samples``) over the evaluation grid.
                    """

                    def _mean_var(x, log: bool = False):
                        # Per-gene (column-wise) mean and variance.
                        mean = np.mean(x, axis=0)
                        var = np.var(x, axis=0)
                        keep = mean > 0
                        if log:
                            # log(0) is -inf, which breaks both the grid and the KDE fit,
                            # so genes without any variance are dropped as well.
                            keep = keep & (var > 0)
                        mean = mean[keep]
                        var = var[keep]
                        if log:
                            mean = np.log(mean)
                            var = np.log(var)
                        return mean, var

                    def _prep_kde_data(x, y):
                        # Stack two 1D vectors as the columns of an (n, 2) sample matrix.
                        return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)

                    def _grid_points(d, n_points=100):
                        # Make grid, add 1 bin on lower/upper end to get final n_points
                        d_min = d.min()
                        d_max = d.max()
                        # Compute bin size
                        d_bin = (d_max - d_min) / (n_points - 2)
                        d_min = d_min - d_bin
                        d_max = d_max + d_bin
                        # Grid points are the bin centers.
                        return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)

                    def _parallel_score_samples(kde, samples, thread_count=None):
                        if thread_count is None:
                            # the thread_count is determined using the factor 0.875 as recommended here:
                            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
                            # Clamp to one worker: int(0.875 * 1) is 0 on single-core hosts and
                            # neither multiprocessing.Pool nor np.array_split accepts 0.
                            thread_count = max(1, int(0.875 * multiprocessing.cpu_count()))
                        if thread_count == 1:
                            # A worker pool would only add overhead here.
                            return kde.score_samples(samples)
                        with multiprocessing.Pool(thread_count) as p:
                            return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))

                    def _kde_eval(d, grid):
                        # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
                        # can not be compared well on regions further away from the data as they are -inf
                        kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
                        return _parallel_score_samples(kde, grid)

                    mean_x, var_x = _mean_var(X, log=True)
                    mean_y, var_y = _mean_var(Y, log=True)

                    x = _prep_kde_data(mean_x, var_x)
                    y = _prep_kde_data(mean_y, var_y)

                    # Gridpoints to eval KDE on; the grid spans the value range of both inputs
                    mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
                    var_grid = _grid_points(np.concatenate([var_x, var_y]))
                    grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)

                    kde_x = _kde_eval(x, grid)
                    kde_y = _kde_eval(y, grid)

                    kde_diff = ((kde_x - kde_y) ** 2).mean()

                    return kde_diff

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this metric.

                    Raises:
                        NotImplementedError: Always.
                    """
                    raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
         | 
| 1177 | 
            +
             | 
| 1178 | 
            +
             | 
| 1179 | 
            +
            class MahalanobisDistance(AbstractDistance):
                """Mahalanobis distance between pseudobulk vectors.

                Both inputs are collapsed along axis 0 with ``aggregation_func`` and the
                distance between the two resulting vectors is measured with respect to the
                feature covariance of ``X`` only, so the measure is not symmetric in ``X`` and ``Y``.
                """

                def __init__(self, aggregation_func: Callable = np.mean) -> None:
                    """Initialize the distance.

                    Args:
                        aggregation_func: Callable invoked as ``aggregation_func(M, axis=0)``
                            to build the pseudobulk vector of a matrix.
                    """
                    super().__init__()
                    # The metric needs the raw matrices, not pairwise cell distances.
                    self.accepts_precomputed = False
                    self.aggregation_func = aggregation_func

                def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
                    """Compute the Mahalanobis distance between the pseudobulks of X and Y.

                    Args:
                        X: First matrix; its columns define the covariance used for scaling.
                        Y: Second matrix with the same number of columns as ``X``.
                        **kwargs: Ignored.

                    Returns:
                        Mahalanobis distance between the aggregated vectors.
                    """
                    # np.cov returns a 0-d array for a single feature; inversion needs 2-D.
                    covariance = np.atleast_2d(np.cov(X.T))
                    try:
                        inverse_covariance = np.linalg.inv(covariance)
                    except np.linalg.LinAlgError:
                        # Singular covariance (e.g. fewer rows than columns, or constant or
                        # collinear columns): use the Moore-Penrose pseudo-inverse instead of failing.
                        inverse_covariance = np.linalg.pinv(covariance)
                    return mahalanobis(
                        self.aggregation_func(X, axis=0),
                        self.aggregation_func(Y, axis=0),
                        inverse_covariance,
                    )

                def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
                    """Not supported for this metric.

                    Raises:
                        NotImplementedError: Always.
                    """
                    raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.")
         |