pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
 - pertpy/data/__init__.py +66 -1
 - pertpy/data/_dataloader.py +28 -26
 - pertpy/data/_datasets.py +261 -92
 - pertpy/metadata/__init__.py +6 -0
 - pertpy/metadata/_cell_line.py +795 -0
 - pertpy/metadata/_compound.py +128 -0
 - pertpy/metadata/_drug.py +238 -0
 - pertpy/metadata/_look_up.py +569 -0
 - pertpy/metadata/_metadata.py +70 -0
 - pertpy/metadata/_moa.py +125 -0
 - pertpy/plot/__init__.py +0 -13
 - pertpy/preprocessing/__init__.py +2 -0
 - pertpy/preprocessing/_guide_rna.py +89 -6
 - pertpy/tools/__init__.py +48 -15
 - pertpy/tools/_augur.py +329 -32
 - pertpy/tools/_cinemaot.py +145 -6
 - pertpy/tools/_coda/_base_coda.py +1237 -116
 - pertpy/tools/_coda/_sccoda.py +66 -36
 - pertpy/tools/_coda/_tasccoda.py +46 -39
 - pertpy/tools/_dialogue.py +180 -77
 - pertpy/tools/_differential_gene_expression/__init__.py +20 -0
 - pertpy/tools/_differential_gene_expression/_base.py +657 -0
 - pertpy/tools/_differential_gene_expression/_checks.py +41 -0
 - pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
 - pertpy/tools/_differential_gene_expression/_edger.py +125 -0
 - pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
 - pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
 - pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
 - pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
 - pertpy/tools/_distances/_distance_tests.py +29 -24
 - pertpy/tools/_distances/_distances.py +584 -98
 - pertpy/tools/_enrichment.py +460 -0
 - pertpy/tools/_kernel_pca.py +1 -1
 - pertpy/tools/_milo.py +406 -49
 - pertpy/tools/_mixscape.py +677 -55
 - pertpy/tools/_perturbation_space/_clustering.py +10 -3
 - pertpy/tools/_perturbation_space/_comparison.py +112 -0
 - pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
 - pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
 - pertpy/tools/_perturbation_space/_simple.py +52 -11
 - pertpy/tools/_scgen/__init__.py +1 -1
 - pertpy/tools/_scgen/_base_components.py +2 -3
 - pertpy/tools/_scgen/_scgen.py +706 -0
 - pertpy/tools/_scgen/_utils.py +3 -5
 - pertpy/tools/decoupler_LICENSE +674 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
 - pertpy-0.8.0.dist-info/RECORD +57 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
 - pertpy/plot/_augur.py +0 -234
 - pertpy/plot/_cinemaot.py +0 -81
 - pertpy/plot/_coda.py +0 -1001
 - pertpy/plot/_dialogue.py +0 -91
 - pertpy/plot/_guide_rna.py +0 -82
 - pertpy/plot/_milopy.py +0 -284
 - pertpy/plot/_mixscape.py +0 -594
 - pertpy/plot/_scgen.py +0 -337
 - pertpy/tools/_differential_gene_expression.py +0 -99
 - pertpy/tools/_metadata/__init__.py +0 -0
 - pertpy/tools/_metadata/_cell_line.py +0 -613
 - pertpy/tools/_metadata/_look_up.py +0 -342
 - pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
 - pertpy/tools/_scgen/_jax_scgen.py +0 -370
 - pertpy-0.6.0.dist-info/RECORD +0 -50
 - /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
 - {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
 
| 
         @@ -1,370 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from __future__ import annotations
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            from typing import TYPE_CHECKING, Any
         
     | 
| 
       4 
     | 
    
         
            -
             
     | 
| 
       5 
     | 
    
         
            -
            import jax.numpy as jnp
         
     | 
| 
       6 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       7 
     | 
    
         
            -
            from anndata import AnnData
         
     | 
| 
       8 
     | 
    
         
            -
            from jax import Array
         
     | 
| 
       9 
     | 
    
         
            -
            from scvi import REGISTRY_KEYS
         
     | 
| 
       10 
     | 
    
         
            -
            from scvi.data import AnnDataManager
         
     | 
| 
       11 
     | 
    
         
            -
            from scvi.data.fields import CategoricalObsField, LayerField
         
     | 
| 
       12 
     | 
    
         
            -
            from scvi.model.base import BaseModelClass, JaxTrainingMixin
         
     | 
| 
       13 
     | 
    
         
            -
            from scvi.utils import setup_anndata_dsp
         
     | 
| 
       14 
     | 
    
         
            -
             
     | 
| 
       15 
     | 
    
         
            -
            from ._jax_scgenvae import JaxSCGENVAE
         
     | 
| 
       16 
     | 
    
         
            -
            from ._utils import balancer, extractor
         
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
       18 
     | 
    
         
            -
            if TYPE_CHECKING:
         
     | 
| 
       19 
     | 
    
         
            -
                from collections.abc import Sequence
         
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
       21 
     | 
    
         
            -
             # Module-level font settings (Arial, size 14).
             # NOTE(review): not referenced in the visible portion of this module — confirm whether it is still used.
             font = {"family": "Arial", "size": 14}
         
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
            class SCGEN(JaxTrainingMixin, BaseModelClass):
         
     | 
| 
       25 
     | 
    
         
            -
                """Jax Implementation of scGen model for batch removal and perturbation prediction."""
         
     | 
| 
       26 
     | 
    
         
            -
             
     | 
| 
       27 
     | 
    
         
            -
                 def __init__(
                     self,
                     adata: AnnData,
                     n_hidden: int = 800,
                     n_latent: int = 100,
                     n_layers: int = 2,
                     dropout_rate: float = 0.2,
                     **model_kwargs,
                 ):
                     """Build the scGen model on top of a ``JaxSCGENVAE`` module.

                     Args:
                         adata: AnnData object passed to the base class constructor. Presumably it must already have been
                             registered via ``setup_anndata`` (TODO confirm against scvi-tools).
                         n_hidden: Forwarded to ``JaxSCGENVAE`` as ``n_hidden`` (presumably the hidden-layer width).
                         n_latent: Forwarded to ``JaxSCGENVAE`` as ``n_latent`` (presumably the latent dimensionality).
                         n_layers: Forwarded to ``JaxSCGENVAE`` as ``n_layers``.
                         dropout_rate: Forwarded to ``JaxSCGENVAE`` as ``dropout_rate``.
                         **model_kwargs: Extra keyword arguments forwarded unchanged to ``JaxSCGENVAE``.
                     """
                     super().__init__(adata)

                     # The module's input size is taken from the registered data summary (number of variables).
                     self.module = JaxSCGENVAE(
                         n_input=self.summary_stats.n_vars,
                         n_hidden=n_hidden,
                         n_latent=n_latent,
                         n_layers=n_layers,
                         dropout_rate=dropout_rate,
                         **model_kwargs,
                     )
                     # Summary string listing the architecture hyperparameters of this model instance.
                     self._model_summary_string = (
                         "SCGEN Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: " "{}"
                     ).format(
                         n_hidden,
                         n_latent,
                         n_layers,
                         dropout_rate,
                     )
                     # Captures the constructor's local namespace via `locals()`; avoid introducing extra locals above
                     # this line. NOTE(review): confirm how `_get_init_params` filters these entries.
                     self.init_params_ = self._get_init_params(locals())
         
     | 
| 
       55 
     | 
    
         
            -
             
     | 
| 
       56 
     | 
    
         
            -
                def predict(
         
     | 
| 
       57 
     | 
    
         
            -
                    self,
         
     | 
| 
       58 
     | 
    
         
            -
                    ctrl_key=None,
         
     | 
| 
       59 
     | 
    
         
            -
                    stim_key=None,
         
     | 
| 
       60 
     | 
    
         
            -
                    adata_to_predict=None,
         
     | 
| 
       61 
     | 
    
         
            -
                    celltype_to_predict=None,
         
     | 
| 
       62 
     | 
    
         
            -
                    restrict_arithmetic_to="all",
         
     | 
| 
       63 
     | 
    
         
            -
                ) -> tuple[AnnData, Any]:
         
     | 
| 
       64 
     | 
    
         
            -
                    """Predicts the cell type provided by the user in stimulated condition.
         
     | 
| 
       65 
     | 
    
         
            -
             
     | 
| 
       66 
     | 
    
         
            -
                    Args:
         
     | 
| 
       67 
     | 
    
         
            -
                        ctrl_key: Key for `control` part of the `data` found in `condition_key`.
         
     | 
| 
       68 
     | 
    
         
            -
                        stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
         
     | 
| 
       69 
     | 
    
         
            -
                        adata_to_predict: Adata for unperturbed cells you want to be predicted.
         
     | 
| 
       70 
     | 
    
         
            -
                        celltype_to_predict: The cell type you want to be predicted.
         
     | 
| 
       71 
     | 
    
         
            -
                        restrict_arithmetic_to: Dictionary of celltypes you want to be observed for prediction.
         
     | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       74 
     | 
    
         
            -
                        `np nd-array` of predicted cells in primary space.
         
     | 
| 
       75 
     | 
    
         
            -
                    delta: float
         
     | 
| 
       76 
     | 
    
         
            -
                        Difference between stimulated and control cells in latent space
         
     | 
| 
       77 
     | 
    
         
            -
             
     | 
| 
       78 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       79 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       80 
     | 
    
         
            -
                        >>> data = pt.dt.kang_2018()
         
     | 
| 
       81 
     | 
    
         
            -
                        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
         
     | 
| 
       82 
     | 
    
         
            -
                        >>> model = pt.tl.SCGEN(data)
         
     | 
| 
       83 
     | 
    
         
            -
                        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
         
     | 
| 
       84 
     | 
    
         
            -
                        >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
         
     | 
| 
       85 
     | 
    
         
            -
                    """
         
     | 
| 
       86 
     | 
    
         
            -
                    # use keys registered from `setup_anndata()`
         
     | 
| 
       87 
     | 
    
         
            -
                    cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
         
     | 
| 
       88 
     | 
    
         
            -
                    condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
         
     | 
| 
       89 
     | 
    
         
            -
             
     | 
| 
       90 
     | 
    
         
            -
                    if restrict_arithmetic_to == "all":
         
     | 
| 
       91 
     | 
    
         
            -
                        ctrl_x = self.adata[self.adata.obs[condition_key] == ctrl_key, :]
         
     | 
| 
       92 
     | 
    
         
            -
                        stim_x = self.adata[self.adata.obs[condition_key] == stim_key, :]
         
     | 
| 
       93 
     | 
    
         
            -
                        ctrl_x = balancer(ctrl_x, cell_type_key)
         
     | 
| 
       94 
     | 
    
         
            -
                        stim_x = balancer(stim_x, cell_type_key)
         
     | 
| 
       95 
     | 
    
         
            -
                    else:
         
     | 
| 
       96 
     | 
    
         
            -
                        key = list(restrict_arithmetic_to.keys())[0]
         
     | 
| 
       97 
     | 
    
         
            -
                        values = restrict_arithmetic_to[key]
         
     | 
| 
       98 
     | 
    
         
            -
                        subset = self.adata[self.adata.obs[key].isin(values)]
         
     | 
| 
       99 
     | 
    
         
            -
                        ctrl_x = subset[subset.obs[condition_key] == ctrl_key, :]
         
     | 
| 
       100 
     | 
    
         
            -
                        stim_x = subset[subset.obs[condition_key] == stim_key, :]
         
     | 
| 
       101 
     | 
    
         
            -
                        if len(values) > 1:
         
     | 
| 
       102 
     | 
    
         
            -
                            ctrl_x = balancer(ctrl_x, cell_type_key)
         
     | 
| 
       103 
     | 
    
         
            -
                            stim_x = balancer(stim_x, cell_type_key)
         
     | 
| 
       104 
     | 
    
         
            -
                    if celltype_to_predict is not None and adata_to_predict is not None:
         
     | 
| 
       105 
     | 
    
         
            -
                        raise Exception("Please provide either a cell type or adata not both!")
         
     | 
| 
       106 
     | 
    
         
            -
                    if celltype_to_predict is None and adata_to_predict is None:
         
     | 
| 
       107 
     | 
    
         
            -
                        raise Exception("Please provide a cell type name or adata for your unperturbed cells")
         
     | 
| 
       108 
     | 
    
         
            -
                    if celltype_to_predict is not None:
         
     | 
| 
       109 
     | 
    
         
            -
                        ctrl_pred = extractor(
         
     | 
| 
       110 
     | 
    
         
            -
                            self.adata,
         
     | 
| 
       111 
     | 
    
         
            -
                            celltype_to_predict,
         
     | 
| 
       112 
     | 
    
         
            -
                            condition_key,
         
     | 
| 
       113 
     | 
    
         
            -
                            cell_type_key,
         
     | 
| 
       114 
     | 
    
         
            -
                            ctrl_key,
         
     | 
| 
       115 
     | 
    
         
            -
                            stim_key,
         
     | 
| 
       116 
     | 
    
         
            -
                        )[1]
         
     | 
| 
       117 
     | 
    
         
            -
                    else:
         
     | 
| 
       118 
     | 
    
         
            -
                        ctrl_pred = adata_to_predict
         
     | 
| 
       119 
     | 
    
         
            -
             
     | 
| 
       120 
     | 
    
         
            -
                    eq = min(ctrl_x.X.shape[0], stim_x.X.shape[0])
         
     | 
| 
       121 
     | 
    
         
            -
                    rng = np.random.default_rng()
         
     | 
| 
       122 
     | 
    
         
            -
                    cd_ind = rng.choice(range(ctrl_x.shape[0]), size=eq, replace=False)
         
     | 
| 
       123 
     | 
    
         
            -
                    stim_ind = rng.choice(range(stim_x.shape[0]), size=eq, replace=False)
         
     | 
| 
       124 
     | 
    
         
            -
                    ctrl_adata = ctrl_x[cd_ind, :]
         
     | 
| 
       125 
     | 
    
         
            -
                    stim_adata = stim_x[stim_ind, :]
         
     | 
| 
       126 
     | 
    
         
            -
             
     | 
| 
       127 
     | 
    
         
            -
                    latent_ctrl = self._avg_vector(ctrl_adata)
         
     | 
| 
       128 
     | 
    
         
            -
                    latent_stim = self._avg_vector(stim_adata)
         
     | 
| 
       129 
     | 
    
         
            -
             
     | 
| 
       130 
     | 
    
         
            -
                    delta = latent_stim - latent_ctrl
         
     | 
| 
       131 
     | 
    
         
            -
             
     | 
| 
       132 
     | 
    
         
            -
                    latent_cd = self.get_latent_representation(ctrl_pred)
         
     | 
| 
       133 
     | 
    
         
            -
             
     | 
| 
       134 
     | 
    
         
            -
                    stim_pred = delta + latent_cd
         
     | 
| 
       135 
     | 
    
         
            -
                    predicted_cells = self.module.as_bound().generative(stim_pred)["px"]
         
     | 
| 
       136 
     | 
    
         
            -
             
     | 
| 
       137 
     | 
    
         
            -
                    predicted_adata = AnnData(
         
     | 
| 
       138 
     | 
    
         
            -
                        X=np.array(predicted_cells),
         
     | 
| 
       139 
     | 
    
         
            -
                        obs=ctrl_pred.obs.copy(),
         
     | 
| 
       140 
     | 
    
         
            -
                        var=ctrl_pred.var.copy(),
         
     | 
| 
       141 
     | 
    
         
            -
                        obsm=ctrl_pred.obsm.copy(),
         
     | 
| 
       142 
     | 
    
         
            -
                    )
         
     | 
| 
       143 
     | 
    
         
            -
                    return predicted_adata, delta
         
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
     | 
    
         
            -
                def _avg_vector(self, adata):
         
     | 
| 
       146 
     | 
    
         
            -
                    return np.mean(self.get_latent_representation(adata), axis=0)
         
     | 
| 
       147 
     | 
    
         
            -
             
     | 
| 
       148 
     | 
    
         
            -
                def get_decoded_expression(
         
     | 
| 
       149 
     | 
    
         
            -
                    self,
         
     | 
| 
       150 
     | 
    
         
            -
                    adata: AnnData | None = None,
         
     | 
| 
       151 
     | 
    
         
            -
                    indices: Sequence[int] | None = None,
         
     | 
| 
       152 
     | 
    
         
            -
                    batch_size: int | None = None,
         
     | 
| 
       153 
     | 
    
         
            -
                ) -> Array:
         
     | 
| 
       154 
     | 
    
         
            -
                    """Get decoded expression.
         
     | 
| 
       155 
     | 
    
         
            -
             
     | 
| 
       156 
     | 
    
         
            -
                    Args:
         
     | 
| 
       157 
     | 
    
         
            -
                        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
         
     | 
| 
       158 
     | 
    
         
            -
                               AnnData object used to initialize the model.
         
     | 
| 
       159 
     | 
    
         
            -
                        indices: Indices of cells in adata to use. If `None`, all cells are used.
         
     | 
| 
       160 
     | 
    
         
            -
                        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
         
     | 
| 
       161 
     | 
    
         
            -
             
     | 
| 
       162 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       163 
     | 
    
         
            -
                        Decoded expression for each cell
         
     | 
| 
       164 
     | 
    
         
            -
             
     | 
| 
       165 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       166 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       167 
     | 
    
         
            -
                        >>> data = pt.dt.kang_2018()
         
     | 
| 
       168 
     | 
    
         
            -
                        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
         
     | 
| 
       169 
     | 
    
         
            -
                        >>> model = pt.tl.SCGEN(data)
         
     | 
| 
       170 
     | 
    
         
            -
                        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
         
     | 
| 
       171 
     | 
    
         
            -
                        >>> decoded_X = model.get_decoded_expression()
         
     | 
| 
       172 
     | 
    
         
            -
                    """
         
     | 
| 
       173 
     | 
    
         
            -
                    if self.is_trained_ is False:
         
     | 
| 
       174 
     | 
    
         
            -
                        raise RuntimeError("Please train the model first.")
         
     | 
| 
       175 
     | 
    
         
            -
             
     | 
| 
       176 
     | 
    
         
            -
                    adata = self._validate_anndata(adata)
         
     | 
| 
       177 
     | 
    
         
            -
                    scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
         
     | 
| 
       178 
     | 
    
         
            -
                    decoded = []
         
     | 
| 
       179 
     | 
    
         
            -
                    for tensors in scdl:
         
     | 
| 
       180 
     | 
    
         
            -
                        _, generative_outputs = self.module.as_bound()(tensors, compute_loss=False)
         
     | 
| 
       181 
     | 
    
         
            -
                        px = generative_outputs["px"]
         
     | 
| 
       182 
     | 
    
         
            -
                        decoded.append(px)
         
     | 
| 
       183 
     | 
    
         
            -
             
     | 
| 
       184 
     | 
    
         
            -
                    return jnp.concatenate(decoded)
         
     | 
| 
       185 
     | 
    
         
            -
             
     | 
| 
       186 
     | 
    
         
            -
                def batch_removal(self, adata: AnnData | None = None) -> AnnData:
         
     | 
| 
       187 
     | 
    
         
            -
                    """Removes batch effects.
         
     | 
| 
       188 
     | 
    
         
            -
             
     | 
| 
       189 
     | 
    
         
            -
                    Args:
         
     | 
| 
       190 
     | 
    
         
            -
                        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
         
     | 
| 
       191 
     | 
    
         
            -
                               AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
         
     | 
| 
       192 
     | 
    
         
            -
                               corresponding to batch and cell type metadata, respectively.
         
     | 
| 
       193 
     | 
    
         
            -
             
     | 
| 
       194 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       195 
     | 
    
         
            -
                        corrected: `~anndata.AnnData`
         
     | 
| 
       196 
     | 
    
         
            -
                        AnnData of corrected gene expression in adata.X and corrected latent space in adata.obsm["latent"].
         
     | 
| 
       197 
     | 
    
         
            -
                        A reference to the original AnnData is in `corrected.raw` if the input adata had no `raw` attribute.
         
     | 
| 
       198 
     | 
    
         
            -
             
     | 
| 
       199 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       200 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       201 
     | 
    
         
            -
                        >>> data = pt.dt.kang_2018()
         
     | 
| 
       202 
     | 
    
         
            -
                        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
         
     | 
| 
       203 
     | 
    
         
            -
                        >>> model = pt.tl.SCGEN(data)
         
     | 
| 
       204 
     | 
    
         
            -
                        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
         
     | 
| 
       205 
     | 
    
         
            -
                        >>> corrected_adata = model.batch_removal()
         
     | 
| 
       206 
     | 
    
         
            -
                    """
         
     | 
| 
       207 
     | 
    
         
            -
                    adata = self._validate_anndata(adata)
         
     | 
| 
       208 
     | 
    
         
            -
                    latent_all = self.get_latent_representation(adata)
         
     | 
| 
       209 
     | 
    
         
            -
                    # use keys registered from `setup_anndata()`
         
     | 
| 
       210 
     | 
    
         
            -
                    cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
         
     | 
| 
       211 
     | 
    
         
            -
                    batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
         
     | 
| 
       212 
     | 
    
         
            -
             
     | 
| 
       213 
     | 
    
         
            -
                    adata_latent = AnnData(latent_all)
         
     | 
| 
       214 
     | 
    
         
            -
                    adata_latent.obs = adata.obs.copy(deep=True)
         
     | 
| 
       215 
     | 
    
         
            -
                    unique_cell_types = np.unique(adata_latent.obs[cell_label_key])
         
     | 
| 
       216 
     | 
    
         
            -
                    shared_ct = []
         
     | 
| 
       217 
     | 
    
         
            -
                    not_shared_ct = []
         
     | 
| 
       218 
     | 
    
         
            -
                    for cell_type in unique_cell_types:
         
     | 
| 
       219 
     | 
    
         
            -
                        temp_cell = adata_latent[adata_latent.obs[cell_label_key] == cell_type].copy()
         
     | 
| 
       220 
     | 
    
         
            -
                        if len(np.unique(temp_cell.obs[batch_key])) < 2:
         
     | 
| 
       221 
     | 
    
         
            -
                            cell_type_ann = adata_latent[adata_latent.obs[cell_label_key] == cell_type]
         
     | 
| 
       222 
     | 
    
         
            -
                            not_shared_ct.append(cell_type_ann)
         
     | 
| 
       223 
     | 
    
         
            -
                            continue
         
     | 
| 
       224 
     | 
    
         
            -
                        temp_cell = adata_latent[adata_latent.obs[cell_label_key] == cell_type].copy()
         
     | 
| 
       225 
     | 
    
         
            -
                        batch_list = {}
         
     | 
| 
       226 
     | 
    
         
            -
                        batch_ind = {}
         
     | 
| 
       227 
     | 
    
         
            -
                        max_batch = 0
         
     | 
| 
       228 
     | 
    
         
            -
                        max_batch_ind = ""
         
     | 
| 
       229 
     | 
    
         
            -
                        batches = np.unique(temp_cell.obs[batch_key])
         
     | 
| 
       230 
     | 
    
         
            -
                        for i in batches:
         
     | 
| 
       231 
     | 
    
         
            -
                            temp = temp_cell[temp_cell.obs[batch_key] == i]
         
     | 
| 
       232 
     | 
    
         
            -
                            temp_ind = temp_cell.obs[batch_key] == i
         
     | 
| 
       233 
     | 
    
         
            -
                            if max_batch < len(temp):
         
     | 
| 
       234 
     | 
    
         
            -
                                max_batch = len(temp)
         
     | 
| 
       235 
     | 
    
         
            -
                                max_batch_ind = i
         
     | 
| 
       236 
     | 
    
         
            -
                            batch_list[i] = temp
         
     | 
| 
       237 
     | 
    
         
            -
                            batch_ind[i] = temp_ind
         
     | 
| 
       238 
     | 
    
         
            -
                        max_batch_ann = batch_list[max_batch_ind]
         
     | 
| 
       239 
     | 
    
         
            -
                        for study in batch_list:
         
     | 
| 
       240 
     | 
    
         
            -
                            delta = np.average(max_batch_ann.X, axis=0) - np.average(batch_list[study].X, axis=0)
         
     | 
| 
       241 
     | 
    
         
            -
                            batch_list[study].X = delta + batch_list[study].X
         
     | 
| 
       242 
     | 
    
         
            -
                            temp_cell[batch_ind[study]].X = batch_list[study].X
         
     | 
| 
       243 
     | 
    
         
            -
                        shared_ct.append(temp_cell)
         
     | 
| 
       244 
     | 
    
         
            -
             
     | 
| 
       245 
     | 
    
         
            -
                    all_shared_ann = AnnData.concatenate(*shared_ct, batch_key="concat_batch", index_unique=None)
         
     | 
| 
       246 
     | 
    
         
            -
                    if "concat_batch" in all_shared_ann.obs.columns:
         
     | 
| 
       247 
     | 
    
         
            -
                        del all_shared_ann.obs["concat_batch"]
         
     | 
| 
       248 
     | 
    
         
            -
                    if len(not_shared_ct) < 1:
         
     | 
| 
       249 
     | 
    
         
            -
                        corrected = AnnData(
         
     | 
| 
       250 
     | 
    
         
            -
                            np.array(self.module.as_bound().generative(all_shared_ann.X)["px"]),
         
     | 
| 
       251 
     | 
    
         
            -
                            obs=all_shared_ann.obs,
         
     | 
| 
       252 
     | 
    
         
            -
                        )
         
     | 
| 
       253 
     | 
    
         
            -
                        corrected.var_names = adata.var_names.tolist()
         
     | 
| 
       254 
     | 
    
         
            -
                        corrected = corrected[adata.obs_names]
         
     | 
| 
       255 
     | 
    
         
            -
                        if adata.raw is not None:
         
     | 
| 
       256 
     | 
    
         
            -
                            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
         
     | 
| 
       257 
     | 
    
         
            -
                            adata_raw.obs_names = adata.obs_names
         
     | 
| 
       258 
     | 
    
         
            -
                            corrected.raw = adata_raw
         
     | 
| 
       259 
     | 
    
         
            -
                        corrected.obsm["latent"] = all_shared_ann.X
         
     | 
| 
       260 
     | 
    
         
            -
                        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)
         
     | 
| 
       261 
     | 
    
         
            -
                        return corrected
         
     | 
| 
       262 
     | 
    
         
            -
                    else:
         
     | 
| 
       263 
     | 
    
         
            -
                        all_not_shared_ann = AnnData.concatenate(*not_shared_ct, batch_key="concat_batch", index_unique=None)
         
     | 
| 
       264 
     | 
    
         
            -
                        all_corrected_data = AnnData.concatenate(
         
     | 
| 
       265 
     | 
    
         
            -
                            all_shared_ann,
         
     | 
| 
       266 
     | 
    
         
            -
                            all_not_shared_ann,
         
     | 
| 
       267 
     | 
    
         
            -
                            batch_key="concat_batch",
         
     | 
| 
       268 
     | 
    
         
            -
                            index_unique=None,
         
     | 
| 
       269 
     | 
    
         
            -
                        )
         
     | 
| 
       270 
     | 
    
         
            -
                        if "concat_batch" in all_shared_ann.obs.columns:
         
     | 
| 
       271 
     | 
    
         
            -
                            del all_corrected_data.obs["concat_batch"]
         
     | 
| 
       272 
     | 
    
         
            -
                        corrected = AnnData(
         
     | 
| 
       273 
     | 
    
         
            -
                            np.array(self.module.as_bound().generative(all_corrected_data.X)["px"]),
         
     | 
| 
       274 
     | 
    
         
            -
                            obs=all_corrected_data.obs,
         
     | 
| 
       275 
     | 
    
         
            -
                        )
         
     | 
| 
       276 
     | 
    
         
            -
                        corrected.var_names = adata.var_names.tolist()
         
     | 
| 
       277 
     | 
    
         
            -
                        corrected = corrected[adata.obs_names]
         
     | 
| 
       278 
     | 
    
         
            -
                        if adata.raw is not None:
         
     | 
| 
       279 
     | 
    
         
            -
                            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
         
     | 
| 
       280 
     | 
    
         
            -
                            adata_raw.obs_names = adata.obs_names
         
     | 
| 
       281 
     | 
    
         
            -
                            corrected.raw = adata_raw
         
     | 
| 
       282 
     | 
    
         
            -
                        corrected.obsm["latent"] = all_corrected_data.X
         
     | 
| 
       283 
     | 
    
         
            -
                        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)
         
     | 
| 
       284 
     | 
    
         
            -
             
     | 
| 
       285 
     | 
    
         
            -
                        return corrected
         
     | 
| 
       286 
     | 
    
         
            -
             
     | 
| 
       287 
     | 
    
         
            -
                @classmethod
         
     | 
| 
       288 
     | 
    
         
            -
                @setup_anndata_dsp.dedent
         
     | 
| 
       289 
     | 
    
         
            -
                def setup_anndata(
         
     | 
| 
       290 
     | 
    
         
            -
                    cls,
         
     | 
| 
       291 
     | 
    
         
            -
                    adata: AnnData,
         
     | 
| 
       292 
     | 
    
         
            -
                    batch_key: str | None = None,
         
     | 
| 
       293 
     | 
    
         
            -
                    labels_key: str | None = None,
         
     | 
| 
       294 
     | 
    
         
            -
                    **kwargs,
         
     | 
| 
       295 
     | 
    
         
            -
                ):
         
     | 
| 
       296 
     | 
    
         
            -
                    """%(summary)s.
         
     | 
| 
       297 
     | 
    
         
            -
             
     | 
| 
       298 
     | 
    
         
            -
                    scGen expects the expression data to come from `adata.X`
         
     | 
| 
       299 
     | 
    
         
            -
             
     | 
| 
       300 
     | 
    
         
            -
                    %(param_batch_key)s
         
     | 
| 
       301 
     | 
    
         
            -
                    %(param_labels_key)s
         
     | 
| 
       302 
     | 
    
         
            -
             
     | 
| 
       303 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       304 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       305 
     | 
    
         
            -
                        >>> data = pt.dt.kang_2018()
         
     | 
| 
       306 
     | 
    
         
            -
                        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
         
     | 
| 
       307 
     | 
    
         
            -
                    """
         
     | 
| 
       308 
     | 
    
         
            -
                    setup_method_args = cls._get_setup_method_args(**locals())
         
     | 
| 
       309 
     | 
    
         
            -
                    anndata_fields = [
         
     | 
| 
       310 
     | 
    
         
            -
                        LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
         
     | 
| 
       311 
     | 
    
         
            -
                        CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
         
     | 
| 
       312 
     | 
    
         
            -
                        CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
         
     | 
| 
       313 
     | 
    
         
            -
                    ]
         
     | 
| 
       314 
     | 
    
         
            -
                    adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
         
     | 
| 
       315 
     | 
    
         
            -
                    adata_manager.register_fields(adata, **kwargs)
         
     | 
| 
       316 
     | 
    
         
            -
                    cls.register_manager(adata_manager)
         
     | 
| 
       317 
     | 
    
         
            -
             
     | 
| 
       318 
     | 
    
         
            -
                def to_device(self, device):
         
     | 
| 
       319 
     | 
    
         
            -
                    pass
         
     | 
| 
       320 
     | 
    
         
            -
             
     | 
| 
       321 
     | 
    
         
            -
                @property
         
     | 
| 
       322 
     | 
    
         
            -
                def device(self):
         
     | 
| 
       323 
     | 
    
         
            -
                    return self.module.device
         
     | 
| 
       324 
     | 
    
         
            -
             
     | 
| 
       325 
     | 
    
         
            -
                def get_latent_representation(
         
     | 
| 
       326 
     | 
    
         
            -
                    self,
         
     | 
| 
       327 
     | 
    
         
            -
                    adata: AnnData | None = None,
         
     | 
| 
       328 
     | 
    
         
            -
                    indices: Sequence[int] | None = None,
         
     | 
| 
       329 
     | 
    
         
            -
                    give_mean: bool = True,
         
     | 
| 
       330 
     | 
    
         
            -
                    n_samples: int = 1,
         
     | 
| 
       331 
     | 
    
         
            -
                    batch_size: int | None = None,
         
     | 
| 
       332 
     | 
    
         
            -
                ) -> np.ndarray:
         
     | 
| 
       333 
     | 
    
         
            -
                    """Return the latent representation for each cell.
         
     | 
| 
       334 
     | 
    
         
            -
             
     | 
| 
       335 
     | 
    
         
            -
                    Args:
         
     | 
| 
       336 
     | 
    
         
            -
                        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
         
     | 
| 
       337 
     | 
    
         
            -
                               AnnData object used to initialize the model.
         
     | 
| 
       338 
     | 
    
         
            -
                        indices: Indices of cells in adata to use. If `None`, all cells are used.
         
     | 
| 
       339 
     | 
    
         
            -
                        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
         
     | 
| 
       340 
     | 
    
         
            -
             
     | 
| 
       341 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       342 
     | 
    
         
            -
                        Low-dimensional representation for each cell
         
     | 
| 
       343 
     | 
    
         
            -
             
     | 
| 
       344 
     | 
    
         
            -
                    Examples:
         
     | 
| 
       345 
     | 
    
         
            -
                        >>> import pertpy as pt
         
     | 
| 
       346 
     | 
    
         
            -
                        >>> data = pt.dt.kang_2018()
         
     | 
| 
       347 
     | 
    
         
            -
                        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
         
     | 
| 
       348 
     | 
    
         
            -
                        >>> model = pt.tl.SCGEN(data)
         
     | 
| 
       349 
     | 
    
         
            -
                        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
         
     | 
| 
       350 
     | 
    
         
            -
                        >>> latent_X = model.get_latent_representation()
         
     | 
| 
       351 
     | 
    
         
            -
                    """
         
     | 
| 
       352 
     | 
    
         
            -
                    self._check_if_trained(warn=False)
         
     | 
| 
       353 
     | 
    
         
            -
             
     | 
| 
       354 
     | 
    
         
            -
                    adata = self._validate_anndata(adata)
         
     | 
| 
       355 
     | 
    
         
            -
                    scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)
         
     | 
| 
       356 
     | 
    
         
            -
             
     | 
| 
       357 
     | 
    
         
            -
                    jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})
         
     | 
| 
       358 
     | 
    
         
            -
             
     | 
| 
       359 
     | 
    
         
            -
                    latent = []
         
     | 
| 
       360 
     | 
    
         
            -
                    for array_dict in scdl:
         
     | 
| 
       361 
     | 
    
         
            -
                        out = jit_inference_fn(self.module.rngs, array_dict)
         
     | 
| 
       362 
     | 
    
         
            -
                        if give_mean:
         
     | 
| 
       363 
     | 
    
         
            -
                            z = out["qz"].mean
         
     | 
| 
       364 
     | 
    
         
            -
                        else:
         
     | 
| 
       365 
     | 
    
         
            -
                            z = out["z"]
         
     | 
| 
       366 
     | 
    
         
            -
                        latent.append(z)
         
     | 
| 
       367 
     | 
    
         
            -
                    concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
         
     | 
| 
       368 
     | 
    
         
            -
                    latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore
         
     | 
| 
       369 
     | 
    
         
            -
             
     | 
| 
       370 
     | 
    
         
            -
                    return self.module.as_numpy_array(latent)
         
     | 
    
        pertpy-0.6.0.dist-info/RECORD
    DELETED
    
    | 
         @@ -1,50 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            pertpy/__init__.py,sha256=3__crpMVG7ky5lmD91Pq9qIGWgUuZQTH8xpiM5qcUJA,546
         
     | 
| 
       2 
     | 
    
         
            -
            pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       3 
     | 
    
         
            -
            pertpy/data/__init__.py,sha256=dvFUk-vAVelA65esA4EbIAVEQoE3s9K6LmE31-j2fC0,1197
         
     | 
| 
       4 
     | 
    
         
            -
            pertpy/data/_dataloader.py,sha256=pNDXLSNzOLeFM_mqf9nNvN_6Y4uA4gJfG3Y7VS-03ko,2397
         
     | 
| 
       5 
     | 
    
         
            -
            pertpy/data/_datasets.py,sha256=q20-f7MT2neWTplN300QOBlu-ihWa8IRTKxUgLgemIw,59496
         
     | 
| 
       6 
     | 
    
         
            -
            pertpy/plot/__init__.py,sha256=HB6nEBfOPmOVRHOJsJ7IJcxx2j6-6oQ__sJRaszBKuk,455
         
     | 
| 
       7 
     | 
    
         
            -
            pertpy/plot/_augur.py,sha256=pRhgc1RdRhXp6xl7-y8Z4o8beUBfltJY3XUeN9GJKbs,9064
         
     | 
| 
       8 
     | 
    
         
            -
            pertpy/plot/_cinemaot.py,sha256=tPTab-5jqalGLfa1NNeevG3_ExbKRfnIE8RRnt8Eecc,3199
         
     | 
| 
       9 
     | 
    
         
            -
            pertpy/plot/_coda.py,sha256=Ma24jc5KhuY3dtIJ6xO-pp0JpW7vWc-TPhSKJMXBEmQ,43650
         
     | 
| 
       10 
     | 
    
         
            -
            pertpy/plot/_dialogue.py,sha256=TGv_fb5f1zPEaJA8SgCue77IJkHKsQLR8f8oIz9SEcE,3881
         
     | 
| 
       11 
     | 
    
         
            -
            pertpy/plot/_guide_rna.py,sha256=Z-_vjHcOIK-DXLDTZGl5HmG6A2TnJBHv9L8VK7L3_fA,3286
         
     | 
| 
       12 
     | 
    
         
            -
            pertpy/plot/_milopy.py,sha256=6K9DtmHiCh6FUb5xScUZTxXUZoRCwD0oyfAMu0SmRGA,10994
         
     | 
| 
       13 
     | 
    
         
            -
            pertpy/plot/_mixscape.py,sha256=KeLCqWRcn2092VqB94PqBtP_wxD_OY4uS8GcZ2RXc7Y,27903
         
     | 
| 
       14 
     | 
    
         
            -
            pertpy/plot/_scgen.py,sha256=KnPe8iOqDDZw0MpSxOU7Xr-2t1UtHKehYgBQ7_4O8d4,15125
         
     | 
| 
       15 
     | 
    
         
            -
            pertpy/preprocessing/__init__.py,sha256=uja9T469LLYQAGgrTyFa4MudXci6NXnAgOn97FHXcxA,40
         
     | 
| 
       16 
     | 
    
         
            -
            pertpy/preprocessing/_guide_rna.py,sha256=EYSrsMP7FpztS0NQhn1xg0oBZZ5RT5fz6YBFvmOab58,4247
         
     | 
| 
       17 
     | 
    
         
            -
            pertpy/tools/__init__.py,sha256=QiFFM1IL7K47vuTbQqjgB8rVzauWmn6JVVpQG9AikvA,1108
         
     | 
| 
       18 
     | 
    
         
            -
            pertpy/tools/_augur.py,sha256=EUe-aRGO-PzszTS8vMfUJtzpfC3CmUSorSJTkEEU60w,45193
         
     | 
| 
       19 
     | 
    
         
            -
            pertpy/tools/_cinemaot.py,sha256=bqbxc88AH4vo2--Y5yLH3anuu1prWDAxoRZaiNvOgtQ,33374
         
     | 
| 
       20 
     | 
    
         
            -
            pertpy/tools/_dialogue.py,sha256=OUSjPzTRi46WG5QARoj2_fpmr7IQ2ftTlXT3-OiiWJc,48116
         
     | 
| 
       21 
     | 
    
         
            -
            pertpy/tools/_differential_gene_expression.py,sha256=mR06huO71KRLcU32ktCWzL-XxA9IGz8OYiRZA26eH0E,3681
         
     | 
| 
       22 
     | 
    
         
            -
            pertpy/tools/_kernel_pca.py,sha256=3S1D_wrp4vlHUPiRbCAoRbUyY-rVs112Qh-BZHSmTxE,1578
         
     | 
| 
       23 
     | 
    
         
            -
            pertpy/tools/_milo.py,sha256=OyLztlNO4Jt1c2aN3WsBbcA0UKVXVvWAnTaKwjPwJ2I,30737
         
     | 
| 
       24 
     | 
    
         
            -
            pertpy/tools/_mixscape.py,sha256=l3YHeyaUUrtuP9P8L5Z7gH47lJpzb0glszMX84DyJBI,23559
         
     | 
| 
       25 
     | 
    
         
            -
            pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
         
     | 
| 
       26 
     | 
    
         
            -
            pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       27 
     | 
    
         
            -
            pertpy/tools/_coda/_base_coda.py,sha256=mxNe5PT1XvIlZmvjQg50kh_bSmeTGVzOC63XLw2TdiI,66859
         
     | 
| 
       28 
     | 
    
         
            -
            pertpy/tools/_coda/_sccoda.py,sha256=cxaqGsXxeLf4guTU1HApAzXN2maQPexsGXIJOlW8UTM,21616
         
     | 
| 
       29 
     | 
    
         
            -
            pertpy/tools/_coda/_tasccoda.py,sha256=q0I7zM_hGjPrpy5dF2Z9trw6u8OqdkrypGgeuAhi26k,30721
         
     | 
| 
       30 
     | 
    
         
            -
            pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       31 
     | 
    
         
            -
            pertpy/tools/_distances/_distance_tests.py,sha256=zRcOeLc18mRnUJ-_usUdVxWn3cZqZ8gLhglt77SaF9k,13604
         
     | 
| 
       32 
     | 
    
         
            -
            pertpy/tools/_distances/_distances.py,sha256=RMNtCD1zkORDE35XWcrh_6mw1c03hOQflmXNfoNtSRA,29780
         
     | 
| 
       33 
     | 
    
         
            -
            pertpy/tools/_metadata/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       34 
     | 
    
         
            -
            pertpy/tools/_metadata/_cell_line.py,sha256=4sUULdmxQ3TFUZDCwikN9TcHG5hf2hzlEO6gOglGl-A,33830
         
     | 
| 
       35 
     | 
    
         
            -
            pertpy/tools/_metadata/_look_up.py,sha256=H7kp9MgfgYMVdxyg3Qpf3_QmqNUkKFNMsswWeA_e1rQ,18200
         
     | 
| 
       36 
     | 
    
         
            -
            pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       37 
     | 
    
         
            -
            pertpy/tools/_perturbation_space/_clustering.py,sha256=ha0TfRKUIFJmL6LE-xIfENAlYyQf4nfTpgg47X_2pHA,3237
         
     | 
| 
       38 
     | 
    
         
            -
            pertpy/tools/_perturbation_space/_discriminator_classifier.py,sha256=hTEAKTnLH4ToSdEHYuJnwui3B8L-zlSR667oG3yb49M,13861
         
     | 
| 
       39 
     | 
    
         
            -
            pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
         
     | 
| 
       40 
     | 
    
         
            -
            pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=_A96OFbpjZULcQGfbsDhXiBjvD0chBl6c-4FoQNoV3w,14169
         
     | 
| 
       41 
     | 
    
         
            -
            pertpy/tools/_perturbation_space/_simple.py,sha256=AZx8GaNJV67evSi5oUkY11QcUkq3EcL0mtkCipjcx6c,10367
         
     | 
| 
       42 
     | 
    
         
            -
            pertpy/tools/_scgen/__init__.py,sha256=bMQ_2QbB4nnzQ7TzhI4DEFfuCDUNbZkL5xDClhQjhcA,49
         
     | 
| 
       43 
     | 
    
         
            -
            pertpy/tools/_scgen/_base_components.py,sha256=dIw-_7Z8iCietPF4tnpM7bFHtDksjnaHXwUjp9GoCIQ,2936
         
     | 
| 
       44 
     | 
    
         
            -
            pertpy/tools/_scgen/_jax_scgen.py,sha256=6fmen3zQm54Yprmd3r7zJK3GIWqpMd034DLGmi-krrs,15368
         
     | 
| 
       45 
     | 
    
         
            -
            pertpy/tools/_scgen/_jax_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
         
     | 
| 
       46 
     | 
    
         
            -
            pertpy/tools/_scgen/_utils.py,sha256=_G9cxBVcTIOs4wN0pgtOSkCsPJoohkeRDIb_anUqSfY,2871
         
     | 
| 
       47 
     | 
    
         
            -
            pertpy-0.6.0.dist-info/METADATA,sha256=bmYUVV99CMPm870ehtSiTbB6lPsYg0kSrmK1aoCvuu8,5046
         
     | 
| 
       48 
     | 
    
         
            -
            pertpy-0.6.0.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87
         
     | 
| 
       49 
     | 
    
         
            -
            pertpy-0.6.0.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
         
     | 
| 
       50 
     | 
    
         
            -
            pertpy-0.6.0.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |