pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,370 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from typing import TYPE_CHECKING, Any
|
4
|
-
|
5
|
-
import jax.numpy as jnp
|
6
|
-
import numpy as np
|
7
|
-
from anndata import AnnData
|
8
|
-
from jax import Array
|
9
|
-
from scvi import REGISTRY_KEYS
|
10
|
-
from scvi.data import AnnDataManager
|
11
|
-
from scvi.data.fields import CategoricalObsField, LayerField
|
12
|
-
from scvi.model.base import BaseModelClass, JaxTrainingMixin
|
13
|
-
from scvi.utils import setup_anndata_dsp
|
14
|
-
|
15
|
-
from ._jax_scgenvae import JaxSCGENVAE
|
16
|
-
from ._utils import balancer, extractor
|
17
|
-
|
18
|
-
if TYPE_CHECKING:
|
19
|
-
from collections.abc import Sequence
|
20
|
-
|
21
|
-
# Font settings (family and point size). Not referenced anywhere else in this module.
font = dict(family="Arial", size=14)
|
22
|
-
|
23
|
-
|
24
|
-
class SCGEN(JaxTrainingMixin, BaseModelClass):
    """Jax Implementation of scGen model for batch removal and perturbation prediction.

    The model wraps a ``JaxSCGENVAE`` module. Register the data with :meth:`SCGEN.setup_anndata`
    (giving ``batch_key`` and ``labels_key``) before constructing the model; :meth:`predict` and
    :meth:`batch_removal` read those registered keys back from the AnnData manager.
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 800,
        n_latent: int = 100,
        n_layers: int = 2,
        dropout_rate: float = 0.2,
        **model_kwargs,
    ):
        """Build the model and its ``JaxSCGENVAE`` module.

        Args:
            adata: AnnData object registered via :meth:`setup_anndata`.
            n_hidden: Hidden layer width, forwarded to ``JaxSCGENVAE``.
            n_latent: Latent space dimensionality, forwarded to ``JaxSCGENVAE``.
            n_layers: Number of hidden layers, forwarded to ``JaxSCGENVAE``.
            dropout_rate: Dropout rate, forwarded to ``JaxSCGENVAE``.
            **model_kwargs: Additional keyword arguments forwarded to ``JaxSCGENVAE``.
        """
        super().__init__(adata)

        self.module = JaxSCGENVAE(
            n_input=self.summary_stats.n_vars,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, "
            f"n_layers: {n_layers}, dropout_rate: {dropout_rate}"
        )
        # Must stay the last statement so `locals()` only holds the constructor arguments.
        self.init_params_ = self._get_init_params(locals())

    def predict(
        self,
        ctrl_key=None,
        stim_key=None,
        adata_to_predict=None,
        celltype_to_predict=None,
        restrict_arithmetic_to="all",
    ) -> tuple[AnnData, Any]:
        """Predicts the cell type provided by the user in stimulated condition.

        The latent shift ``delta`` is the difference between the mean latent vectors of an
        equally-sized random sample of stimulated and control cells. It is added to the latent
        representation of the cells to predict, and the result is decoded back to expression space.

        Args:
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            adata_to_predict: Adata for unperturbed cells you want to be predicted.
            celltype_to_predict: The cell type you want to be predicted.
            restrict_arithmetic_to: Either ``"all"`` or a dictionary mapping one obs column to the list
                of values (e.g. cell types) whose cells are used to compute ``delta``. Only the first
                key of the dictionary is used.

        Returns:
            predicted_adata: AnnData of predicted cells in gene expression space, with ``obs``, ``var``
                and ``obsm`` copied from the cells that were predicted.
            delta: Difference between stimulated and control cells in latent space.

        Raises:
            ValueError: If both or neither of ``celltype_to_predict`` and ``adata_to_predict`` are given.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
        """
        # Validate the mutually exclusive prediction targets before doing any expensive subsetting.
        if celltype_to_predict is not None and adata_to_predict is not None:
            raise ValueError("Please provide either a cell type or adata not both!")
        if celltype_to_predict is None and adata_to_predict is None:
            raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

        # use keys registered from `setup_anndata()`
        cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        if restrict_arithmetic_to == "all":
            ctrl_x = self.adata[self.adata.obs[condition_key] == ctrl_key, :]
            stim_x = self.adata[self.adata.obs[condition_key] == stim_key, :]
            ctrl_x = balancer(ctrl_x, cell_type_key)
            stim_x = balancer(stim_x, cell_type_key)
        else:
            key, values = next(iter(restrict_arithmetic_to.items()))
            subset = self.adata[self.adata.obs[key].isin(values)]
            ctrl_x = subset[subset.obs[condition_key] == ctrl_key, :]
            stim_x = subset[subset.obs[condition_key] == stim_key, :]
            if len(values) > 1:
                ctrl_x = balancer(ctrl_x, cell_type_key)
                stim_x = balancer(stim_x, cell_type_key)

        if celltype_to_predict is not None:
            ctrl_pred = extractor(
                self.adata,
                celltype_to_predict,
                condition_key,
                cell_type_key,
                ctrl_key,
                stim_key,
            )[1]
        else:
            ctrl_pred = adata_to_predict

        # Sample the same number of control and stimulated cells so both means use equal support.
        eq = min(ctrl_x.X.shape[0], stim_x.X.shape[0])
        rng = np.random.default_rng()
        cd_ind = rng.choice(ctrl_x.shape[0], size=eq, replace=False)
        stim_ind = rng.choice(stim_x.shape[0], size=eq, replace=False)
        ctrl_adata = ctrl_x[cd_ind, :]
        stim_adata = stim_x[stim_ind, :]

        delta = self._avg_vector(stim_adata) - self._avg_vector(ctrl_adata)

        latent_cd = self.get_latent_representation(ctrl_pred)
        stim_pred = delta + latent_cd
        predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

        predicted_adata = AnnData(
            X=np.array(predicted_cells),
            obs=ctrl_pred.obs.copy(),
            var=ctrl_pred.var.copy(),
            obsm=ctrl_pred.obsm.copy(),
        )
        return predicted_adata, delta

    def _avg_vector(self, adata):
        """Return the mean latent vector (column-wise mean of the latent representation) of ``adata``."""
        return np.mean(self.get_latent_representation(adata), axis=0)

    def get_decoded_expression(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> Array:
        """Get decoded expression.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Decoded expression for each cell

        Raises:
            RuntimeError: If the model has not been trained yet.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
        if not self.is_trained_:
            raise RuntimeError("Please train the model first.")

        adata = self._validate_anndata(adata)
        # Yield numpy arrays (as `get_latent_representation` does) so the Jax module gets array input.
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)
        decoded = []
        for tensors in scdl:
            _, generative_outputs = self.module.as_bound()(tensors, compute_loss=False)
            decoded.append(generative_outputs["px"])

        return jnp.concatenate(decoded)

    @staticmethod
    def _align_batches_to_largest(cells: AnnData, batch_key: str, batches) -> AnnData:
        """Shift every batch in ``cells`` in place so its mean ``X`` equals the largest batch's mean ``X``."""
        masks = {batch: (cells.obs[batch_key] == batch).to_numpy() for batch in batches}
        # `max` returns the first batch among equally large ones (batches arrive in sorted order).
        largest = max(masks, key=lambda batch: masks[batch].sum())
        X = cells.X
        # The largest batch is shifted by exactly zero, so its mean can be computed once up front.
        target_mean = np.average(X[masks[largest]], axis=0)
        for mask in masks.values():
            X[mask] = (target_mean - np.average(X[mask], axis=0)) + X[mask]
        cells.X = X
        return cells

    def batch_removal(self, adata: AnnData | None = None) -> AnnData:
        """Removes batch effects.

        For every cell type observed in at least two batches, the latent representation of each batch
        is shifted so its mean matches the mean of the largest batch of that cell type. Cell types seen
        in a single batch are left unshifted. The resulting latent values are decoded back to gene
        expression space.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.

        Returns:
            corrected: `~anndata.AnnData`
                AnnData of corrected gene expression in ``corrected.X``, with cells in the same order as the
                input. ``corrected.obsm["latent"]`` holds the batch-shifted latent space (aligned with the
                rows) and ``corrected.obsm["corrected_latent"]`` the latent representation re-encoded from
                the corrected expression. If the input adata has a ``raw`` attribute, it is copied into
                ``corrected.raw``.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
        adata = self._validate_anndata(adata)
        latent_all = self.get_latent_representation(adata)
        # use keys registered from `setup_anndata()`
        cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        adata_latent = AnnData(latent_all)
        adata_latent.obs = adata.obs.copy(deep=True)

        shared_ct = []
        not_shared_ct = []
        for cell_type in np.unique(adata_latent.obs[cell_label_key]):
            temp_cell = adata_latent[adata_latent.obs[cell_label_key] == cell_type].copy()
            batches = np.unique(temp_cell.obs[batch_key])
            if len(batches) < 2:
                # No other batch to align against: keep this cell type's latent values unchanged.
                not_shared_ct.append(temp_cell)
                continue
            shared_ct.append(self._align_batches_to_largest(temp_cell, batch_key, batches))

        # A single concatenation also works when no (or every) cell type is shared across batches.
        all_corrected_data = AnnData.concatenate(
            *shared_ct, *not_shared_ct, batch_key="concat_batch", index_unique=None
        )
        # The bookkeeping column added by `concatenate` is not part of the input metadata.
        if "concat_batch" in all_corrected_data.obs.columns:
            del all_corrected_data.obs["concat_batch"]

        latent = np.asarray(all_corrected_data.X)
        corrected = AnnData(
            np.array(self.module.as_bound().generative(latent)["px"]),
            obs=all_corrected_data.obs,
            obsm={"latent": latent},
        )
        corrected.var_names = adata.var_names.tolist()
        # Restore the input cell order; obsm["latent"] is reordered together with the rows.
        corrected = corrected[adata.obs_names].copy()
        if adata.raw is not None:
            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)

        return corrected

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        labels_key: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        scGen expects the expression data to come from `adata.X`

        %(param_batch_key)s
        %(param_labels_key)s

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def to_device(self, device):
        # Intentionally a no-op: this model does not move itself between devices here.
        pass

    @property
    def device(self):
        """Device reported by the underlying module."""
        return self.module.device

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        n_samples: int = 1,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Return the latent representation for each cell.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            give_mean: If `True`, return the mean of the latent posterior (``qz.mean``); otherwise return
                the sampled latent ``z``.
            n_samples: Number of samples passed to the inference function. When greater than 1 and
                ``give_mean`` is `False`, minibatch results are concatenated along axis 1 instead of axis 0.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Low-dimensional representation for each cell

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
        self._check_if_trained(warn=False)

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)

        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

        latent = []
        for array_dict in scdl:
            out = jit_inference_fn(self.module.rngs, array_dict)
            latent.append(out["qz"].mean if give_mean else out["z"])
        # With several samples per cell the sample axis comes first, so cells lie on axis 1.
        concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
        latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore

        return self.module.as_numpy_array(latent)
|
pertpy-0.6.0.dist-info/RECORD
DELETED
@@ -1,50 +0,0 @@
|
|
1
|
-
pertpy/__init__.py,sha256=3__crpMVG7ky5lmD91Pq9qIGWgUuZQTH8xpiM5qcUJA,546
|
2
|
-
pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
pertpy/data/__init__.py,sha256=dvFUk-vAVelA65esA4EbIAVEQoE3s9K6LmE31-j2fC0,1197
|
4
|
-
pertpy/data/_dataloader.py,sha256=pNDXLSNzOLeFM_mqf9nNvN_6Y4uA4gJfG3Y7VS-03ko,2397
|
5
|
-
pertpy/data/_datasets.py,sha256=q20-f7MT2neWTplN300QOBlu-ihWa8IRTKxUgLgemIw,59496
|
6
|
-
pertpy/plot/__init__.py,sha256=HB6nEBfOPmOVRHOJsJ7IJcxx2j6-6oQ__sJRaszBKuk,455
|
7
|
-
pertpy/plot/_augur.py,sha256=pRhgc1RdRhXp6xl7-y8Z4o8beUBfltJY3XUeN9GJKbs,9064
|
8
|
-
pertpy/plot/_cinemaot.py,sha256=tPTab-5jqalGLfa1NNeevG3_ExbKRfnIE8RRnt8Eecc,3199
|
9
|
-
pertpy/plot/_coda.py,sha256=Ma24jc5KhuY3dtIJ6xO-pp0JpW7vWc-TPhSKJMXBEmQ,43650
|
10
|
-
pertpy/plot/_dialogue.py,sha256=TGv_fb5f1zPEaJA8SgCue77IJkHKsQLR8f8oIz9SEcE,3881
|
11
|
-
pertpy/plot/_guide_rna.py,sha256=Z-_vjHcOIK-DXLDTZGl5HmG6A2TnJBHv9L8VK7L3_fA,3286
|
12
|
-
pertpy/plot/_milopy.py,sha256=6K9DtmHiCh6FUb5xScUZTxXUZoRCwD0oyfAMu0SmRGA,10994
|
13
|
-
pertpy/plot/_mixscape.py,sha256=KeLCqWRcn2092VqB94PqBtP_wxD_OY4uS8GcZ2RXc7Y,27903
|
14
|
-
pertpy/plot/_scgen.py,sha256=KnPe8iOqDDZw0MpSxOU7Xr-2t1UtHKehYgBQ7_4O8d4,15125
|
15
|
-
pertpy/preprocessing/__init__.py,sha256=uja9T469LLYQAGgrTyFa4MudXci6NXnAgOn97FHXcxA,40
|
16
|
-
pertpy/preprocessing/_guide_rna.py,sha256=EYSrsMP7FpztS0NQhn1xg0oBZZ5RT5fz6YBFvmOab58,4247
|
17
|
-
pertpy/tools/__init__.py,sha256=QiFFM1IL7K47vuTbQqjgB8rVzauWmn6JVVpQG9AikvA,1108
|
18
|
-
pertpy/tools/_augur.py,sha256=EUe-aRGO-PzszTS8vMfUJtzpfC3CmUSorSJTkEEU60w,45193
|
19
|
-
pertpy/tools/_cinemaot.py,sha256=bqbxc88AH4vo2--Y5yLH3anuu1prWDAxoRZaiNvOgtQ,33374
|
20
|
-
pertpy/tools/_dialogue.py,sha256=OUSjPzTRi46WG5QARoj2_fpmr7IQ2ftTlXT3-OiiWJc,48116
|
21
|
-
pertpy/tools/_differential_gene_expression.py,sha256=mR06huO71KRLcU32ktCWzL-XxA9IGz8OYiRZA26eH0E,3681
|
22
|
-
pertpy/tools/_kernel_pca.py,sha256=3S1D_wrp4vlHUPiRbCAoRbUyY-rVs112Qh-BZHSmTxE,1578
|
23
|
-
pertpy/tools/_milo.py,sha256=OyLztlNO4Jt1c2aN3WsBbcA0UKVXVvWAnTaKwjPwJ2I,30737
|
24
|
-
pertpy/tools/_mixscape.py,sha256=l3YHeyaUUrtuP9P8L5Z7gH47lJpzb0glszMX84DyJBI,23559
|
25
|
-
pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
|
26
|
-
pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
|
-
pertpy/tools/_coda/_base_coda.py,sha256=mxNe5PT1XvIlZmvjQg50kh_bSmeTGVzOC63XLw2TdiI,66859
|
28
|
-
pertpy/tools/_coda/_sccoda.py,sha256=cxaqGsXxeLf4guTU1HApAzXN2maQPexsGXIJOlW8UTM,21616
|
29
|
-
pertpy/tools/_coda/_tasccoda.py,sha256=q0I7zM_hGjPrpy5dF2Z9trw6u8OqdkrypGgeuAhi26k,30721
|
30
|
-
pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
|
-
pertpy/tools/_distances/_distance_tests.py,sha256=zRcOeLc18mRnUJ-_usUdVxWn3cZqZ8gLhglt77SaF9k,13604
|
32
|
-
pertpy/tools/_distances/_distances.py,sha256=RMNtCD1zkORDE35XWcrh_6mw1c03hOQflmXNfoNtSRA,29780
|
33
|
-
pertpy/tools/_metadata/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
34
|
-
pertpy/tools/_metadata/_cell_line.py,sha256=4sUULdmxQ3TFUZDCwikN9TcHG5hf2hzlEO6gOglGl-A,33830
|
35
|
-
pertpy/tools/_metadata/_look_up.py,sha256=H7kp9MgfgYMVdxyg3Qpf3_QmqNUkKFNMsswWeA_e1rQ,18200
|
36
|
-
pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
|
-
pertpy/tools/_perturbation_space/_clustering.py,sha256=ha0TfRKUIFJmL6LE-xIfENAlYyQf4nfTpgg47X_2pHA,3237
|
38
|
-
pertpy/tools/_perturbation_space/_discriminator_classifier.py,sha256=hTEAKTnLH4ToSdEHYuJnwui3B8L-zlSR667oG3yb49M,13861
|
39
|
-
pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
|
40
|
-
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=_A96OFbpjZULcQGfbsDhXiBjvD0chBl6c-4FoQNoV3w,14169
|
41
|
-
pertpy/tools/_perturbation_space/_simple.py,sha256=AZx8GaNJV67evSi5oUkY11QcUkq3EcL0mtkCipjcx6c,10367
|
42
|
-
pertpy/tools/_scgen/__init__.py,sha256=bMQ_2QbB4nnzQ7TzhI4DEFfuCDUNbZkL5xDClhQjhcA,49
|
43
|
-
pertpy/tools/_scgen/_base_components.py,sha256=dIw-_7Z8iCietPF4tnpM7bFHtDksjnaHXwUjp9GoCIQ,2936
|
44
|
-
pertpy/tools/_scgen/_jax_scgen.py,sha256=6fmen3zQm54Yprmd3r7zJK3GIWqpMd034DLGmi-krrs,15368
|
45
|
-
pertpy/tools/_scgen/_jax_scgenvae.py,sha256=v_6tZ4wY-JjdMH1QVd_wG4_N0PoaqB-FM8zC2JsDu1o,3935
|
46
|
-
pertpy/tools/_scgen/_utils.py,sha256=_G9cxBVcTIOs4wN0pgtOSkCsPJoohkeRDIb_anUqSfY,2871
|
47
|
-
pertpy-0.6.0.dist-info/METADATA,sha256=bmYUVV99CMPm870ehtSiTbB6lPsYg0kSrmK1aoCvuu8,5046
|
48
|
-
pertpy-0.6.0.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87
|
49
|
-
pertpy-0.6.0.dist-info/licenses/LICENSE,sha256=OZ-ZkXM5CmExJiEMM90b_7dGNNvRpj7kdE-49AnrLuI,1070
|
50
|
-
pertpy-0.6.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|