pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,706 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any
|
4
|
+
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import matplotlib.pyplot as plt
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import scanpy as sc
|
10
|
+
from adjustText import adjust_text
|
11
|
+
from anndata import AnnData
|
12
|
+
from jax import Array
|
13
|
+
from lamin_utils import logger
|
14
|
+
from scipy import stats
|
15
|
+
from scvi import REGISTRY_KEYS
|
16
|
+
from scvi.data import AnnDataManager
|
17
|
+
from scvi.data.fields import CategoricalObsField, LayerField
|
18
|
+
from scvi.model.base import BaseModelClass, JaxTrainingMixin
|
19
|
+
from scvi.utils import setup_anndata_dsp
|
20
|
+
|
21
|
+
from ._scgenvae import JaxSCGENVAE
|
22
|
+
from ._utils import balancer, extractor
|
23
|
+
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
from collections.abc import Sequence
|
26
|
+
|
27
|
+
font = {"family": "Arial", "size": 14}
|
28
|
+
|
29
|
+
|
30
|
+
class Scgen(JaxTrainingMixin, BaseModelClass):
|
31
|
+
"""Jax Implementation of scGen model for batch removal and perturbation prediction."""
|
32
|
+
|
33
|
+
def __init__(
    self,
    adata: AnnData,
    n_hidden: int = 800,
    n_latent: int = 100,
    n_layers: int = 2,
    dropout_rate: float = 0.2,
    **model_kwargs,
):
    """Create the scGen model for a registered AnnData object.

    Args:
        adata: AnnData object handed to the base model constructor.
        n_hidden: Forwarded to `JaxSCGENVAE` as `n_hidden`.
        n_latent: Forwarded to `JaxSCGENVAE` as `n_latent`.
        n_layers: Forwarded to `JaxSCGENVAE` as `n_layers`.
        dropout_rate: Forwarded to `JaxSCGENVAE` as `dropout_rate`.
        **model_kwargs: Additional keyword arguments forwarded to `JaxSCGENVAE`.
    """
    super().__init__(adata)

    # The input width of the module is taken from the summary statistics of the registered data.
    self.module = JaxSCGENVAE(
        n_input=self.summary_stats.n_vars,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        **model_kwargs,
    )
    self._model_summary_string = (
        "Scgen Model with the following params: \n"
        f"n_hidden: {n_hidden}, n_latent: {n_latent}, "
        f"n_layers: {n_layers}, dropout_rate: {dropout_rate}"
    )
    # No extra local variables are introduced above, so `locals()` holds what the original passed on.
    self.init_params_ = self._get_init_params(locals())
|
57
|
+
|
58
|
+
def predict(
    self,
    ctrl_key=None,
    stim_key=None,
    adata_to_predict=None,
    celltype_to_predict=None,
    restrict_arithmetic_to="all",
) -> tuple[AnnData, Any]:
    """Predicts the cell type provided by the user in stimulated condition.

    The shift between the conditions is the difference between the mean latent vectors of two equally
    sized random samples of stimulated and control cells. This shift is added to the latent
    representation of the cells to predict and the result is passed through the generative part of the module.

    Args:
        ctrl_key: Key for `control` part of the `data` found in `condition_key`.
        stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
        adata_to_predict: Adata for unperturbed cells you want to be predicted.
            Mutually exclusive with `celltype_to_predict`.
        celltype_to_predict: The cell type you want to be predicted.
            Mutually exclusive with `adata_to_predict`.
        restrict_arithmetic_to: Either `"all"` to use every cell for the latent arithmetic or a dictionary
            mapping an `adata.obs` column to the values of that column that should be used.
            Only the first entry of the dictionary is considered.

    Returns:
        A tuple of the predicted cells and `delta`.
        The predicted cells are an AnnData object whose `obs`, `var` and `obsm` are copied from the cells
        that were predicted. `delta` is the difference between the mean latent vectors of stimulated
        and control cells.

    Raises:
        ValueError: If both or neither of `celltype_to_predict` and `adata_to_predict` are provided,
            or if `restrict_arithmetic_to` is empty.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.Scgen(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
    """
    # Validate the target first so that no subsetting or balancing is done for an invalid call.
    if celltype_to_predict is not None and adata_to_predict is not None:
        raise ValueError("Please provide either a cell type or adata not both!")
    if celltype_to_predict is None and adata_to_predict is None:
        raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

    # use keys registered from `setup_anndata()`
    cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
    condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

    if restrict_arithmetic_to == "all":
        source = self.adata
        rebalance = True
    else:
        if not restrict_arithmetic_to:
            raise ValueError("`restrict_arithmetic_to` must be 'all' or a non-empty dictionary.")
        key, values = next(iter(restrict_arithmetic_to.items()))
        source = self.adata[self.adata.obs[key].isin(values)]
        # A single retained value is used as is; balancing is only applied for several values.
        rebalance = len(values) > 1

    ctrl_x = source[source.obs[condition_key] == ctrl_key, :]
    stim_x = source[source.obs[condition_key] == stim_key, :]
    if rebalance:
        ctrl_x = balancer(ctrl_x, cell_type_key)
        stim_x = balancer(stim_x, cell_type_key)

    if celltype_to_predict is not None:
        ctrl_pred = extractor(
            self.adata,
            celltype_to_predict,
            condition_key,
            cell_type_key,
            ctrl_key,
            stim_key,
        )[1]
    else:
        ctrl_pred = adata_to_predict

    # Draw the same number of cells from both conditions before averaging their latent vectors.
    n_cells = min(ctrl_x.X.shape[0], stim_x.X.shape[0])
    rng = np.random.default_rng()
    cd_ind = rng.choice(ctrl_x.shape[0], size=n_cells, replace=False)
    stim_ind = rng.choice(stim_x.shape[0], size=n_cells, replace=False)
    ctrl_adata = ctrl_x[cd_ind, :]
    stim_adata = stim_x[stim_ind, :]

    latent_ctrl = self._avg_vector(ctrl_adata)
    latent_stim = self._avg_vector(stim_adata)
    delta = latent_stim - latent_ctrl

    latent_cd = self.get_latent_representation(ctrl_pred)
    stim_pred = delta + latent_cd
    predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

    predicted_adata = AnnData(
        X=np.array(predicted_cells),
        obs=ctrl_pred.obs.copy(),
        var=ctrl_pred.var.copy(),
        obsm=ctrl_pred.obsm.copy(),
    )
    return predicted_adata, delta
|
146
|
+
|
147
|
+
def _avg_vector(self, adata):
    """Return the mean over the first axis of the latent representation of `adata`."""
    latent = self.get_latent_representation(adata)
    return np.mean(latent, axis=0)
|
149
|
+
|
150
|
+
def get_decoded_expression(
    self,
    adata: AnnData | None = None,
    indices: Sequence[int] | None = None,
    batch_size: int | None = None,
) -> Array:
    """Get decoded expression.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices: Indices of cells in adata to use. If `None`, all cells are used.
        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns:
        Decoded expression for each cell

    Raises:
        RuntimeError: If the model has not been trained.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.Scgen(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> decoded_X = model.get_decoded_expression()
    """
    if self.is_trained_ is False:
        raise RuntimeError("Please train the model first.")

    validated = self._validate_anndata(adata)
    loader = self._make_data_loader(adata=validated, indices=indices, batch_size=batch_size)

    # One forward pass per minibatch; only the "px" entry of the generative outputs is kept.
    per_batch = [self.module.as_bound()(tensors, compute_loss=False)[1]["px"] for tensors in loader]
    return jnp.concatenate(per_batch)
|
187
|
+
|
188
|
+
def batch_removal(self, adata: AnnData | None = None) -> AnnData:
    """Removes batch effects.

    For every cell type that occurs in at least two batches, the latent vectors of each batch are shifted
    so that their mean matches the mean of the batch with the most cells of that cell type.
    Cell types that occur in a single batch are left unchanged. The latent vectors are then decoded.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
            corresponding to batch and cell type metadata, respectively.

    Returns:
        corrected: `~anndata.AnnData`
        AnnData of corrected gene expression in adata.X, ordered like the input.
        The latent vectors that were decoded are stored in adata.obsm["latent"] and the latent representation
        of the corrected expression in adata.obsm["corrected_latent"].
        If the input adata has a `raw` attribute, its `X` and `var` are carried over to `corrected.raw`.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.Scgen(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> corrected_adata = model.batch_removal()
    """
    adata = self._validate_anndata(adata)
    latent_all = self.get_latent_representation(adata)
    # use keys registered from `setup_anndata()`
    cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
    batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

    adata_latent = AnnData(latent_all)
    adata_latent.obs = adata.obs.copy(deep=True)

    shared_ct = []
    not_shared_ct = []
    for cell_type in np.unique(adata_latent.obs[cell_label_key]):
        temp_cell = adata_latent[adata_latent.obs[cell_label_key] == cell_type].copy()
        batches = np.unique(temp_cell.obs[batch_key])
        if len(batches) < 2:
            # Nothing to align against: this cell type only exists in one batch.
            not_shared_ct.append(temp_cell)
            continue

        batch_ind = {batch: temp_cell.obs[batch_key] == batch for batch in batches}
        batch_list = {batch: temp_cell[mask] for batch, mask in batch_ind.items()}
        # `max` keeps the first of several equally large batches, like a strict `<` comparison would.
        max_batch_ind = max(batch_list, key=lambda batch: len(batch_list[batch]))
        reference_mean = np.average(batch_list[max_batch_ind].X, axis=0)

        for study in batch_list:
            delta = reference_mean - np.average(batch_list[study].X, axis=0)
            batch_list[study].X = delta + batch_list[study].X
            temp_cell[batch_ind[study]].X = batch_list[study].X
        shared_ct.append(temp_cell)

    def _without_concat_batch(merged):
        # `concatenate` adds a helper column that must not end up in the result.
        if "concat_batch" in merged.obs.columns:
            del merged.obs["concat_batch"]
        return merged

    parts = []
    if shared_ct:
        parts.append(
            _without_concat_batch(AnnData.concatenate(*shared_ct, batch_key="concat_batch", index_unique=None))
        )
    if not_shared_ct:
        parts.append(AnnData.concatenate(*not_shared_ct, batch_key="concat_batch", index_unique=None))

    if len(parts) == 2:
        all_corrected_data = AnnData.concatenate(*parts, batch_key="concat_batch", index_unique=None)
    else:
        all_corrected_data = parts[0]
    # Check the object that is actually used below, not the already cleaned shared part.
    all_corrected_data = _without_concat_batch(all_corrected_data)

    corrected = AnnData(
        np.array(self.module.as_bound().generative(all_corrected_data.X)["px"]),
        obs=all_corrected_data.obs,
    )
    corrected.var_names = adata.var_names.tolist()
    # Restore the cell order of the input.
    corrected = corrected[adata.obs_names]
    if adata.raw is not None:
        adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
        adata_raw.obs_names = adata.obs_names
        corrected.raw = adata_raw
    # The latent vectors are reordered in the same way so that rows stay aligned with `corrected.obs`.
    corrected.obsm["latent"] = all_corrected_data[adata.obs_names].X
    corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)

    return corrected
|
288
|
+
|
289
|
+
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
    cls,
    adata: AnnData,
    batch_key: str | None = None,
    labels_key: str | None = None,
    **kwargs,
):
    """%(summary)s.

    scGen expects the expression data to come from `adata.X`

    %(param_batch_key)s
    %(param_labels_key)s

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
    """
    # Must stay the first statement: `locals()` has to contain the call arguments only.
    setup_method_args = cls._get_setup_method_args(**locals())
    manager = AnnDataManager(
        fields=[
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
|
319
|
+
|
320
|
+
def to_device(self, device):
    """Accept a device argument and do nothing with it."""
    return None
|
322
|
+
|
323
|
+
@property
def device(self):
    """Device reported by the wrapped module."""
    module = self.module
    return module.device
|
326
|
+
|
327
|
+
def get_latent_representation(
    self,
    adata: AnnData | None = None,
    indices: Sequence[int] | None = None,
    give_mean: bool = True,
    n_samples: int = 1,
    batch_size: int | None = None,
) -> np.ndarray:
    """Return the latent representation for each cell.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices: Indices of cells in adata to use. If `None`, all cells are used.
        give_mean: If `True`, the mean of the `qz` inference output is returned, otherwise the `z` output.
        n_samples: Passed to the inference function as `n_samples`.
        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns:
        Low-dimensional representation for each cell

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.Scgen(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> latent_X = model.get_latent_representation()
    """
    self._check_if_trained(warn=False)

    validated = self._validate_anndata(adata)
    loader = self._make_data_loader(adata=validated, indices=indices, batch_size=batch_size, iter_ndarray=True)
    infer = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

    chunks = []
    for batch in loader:
        # `self.module.rngs` is read anew for every minibatch, as before.
        outputs = infer(self.module.rngs, batch)
        chunks.append(outputs["qz"].mean if give_mean else outputs["z"])

    # Only sampled output with several samples is joined along the second axis.
    concat_axis = 1 if (n_samples != 1 and not give_mean) else 0
    stacked = jnp.concatenate(chunks, axis=concat_axis)  # type: ignore

    return self.module.as_numpy_array(stacked)
|
373
|
+
|
374
|
+
def plot_reg_mean_plot(
    self,
    adata,
    condition_key: str,
    axis_keys: dict[str, str],
    labels: dict[str, str],
    save: str | bool | None = None,
    gene_list: list[str] | None = None,
    show: bool = False,
    top_100_genes: list[str] | None = None,
    verbose: bool = False,
    legend: bool = True,
    title: str | None = None,
    x_coeff: float = 0.30,
    y_coeff: float = 0.8,
    fontsize: float = 14,
    **kwargs,
) -> tuple[float, float] | float:
    """Plots mean matching for a set of specified genes.

    The per-gene mean expression of the cells of condition `axis_keys["x"]` is regressed against the
    per-gene mean expression of the cells of condition `axis_keys["y"]`.

    Args:
        adata: AnnData object containing the cells of both conditions that are compared.
        condition_key: The key for the condition
        axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
            `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
        save: Location that is handed to `plt.savefig`. The plot is not saved if this is `None` or `False`.
        gene_list: list of gene names to be plotted.
        show: if `True`: will show to the plot after saving it.
        top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
        verbose: Specify if you want information to be printed while creating the plot.
        legend: Whether to plot a legend.
        title: Set if you want the plot to display a title.
        x_coeff: Offset to print the R^2 value in x-direction.
        y_coeff: Offset to print the R^2 value in y-direction.
        fontsize: Fontsize used for text in the plot.
        **kwargs: `range` (a `(start, stop, step)` triple used for the ticks of both axes) and
            `textsize` (font size of the R^2 annotations, defaults to `fontsize`) are recognized.

    Returns:
        The R^2 value of all genes. If `top_100_genes` is provided, a tuple of the R^2 value of all genes
        and the R^2 value of the provided genes.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> scg = pt.tl.Scgen(data)
        >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
        >>> pred.obs['label'] = 'pred'
        >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
        >>> r2_value = scg.plot_reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \
            labels={"x": "predicted", "y": "ground truth"}, save=False, show=True)

    Preview:
        .. image:: /_static/docstring_previews/scgen_reg_mean.png
    """
    import seaborn as sns

    sns.set_theme(color_codes=True)

    diff_genes = top_100_genes
    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

    r_value_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        adata_diff = adata[:, diff_genes]
        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
        x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
        y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
        r_value_diff = stats.linregress(x_diff, y_diff)[2]
        if verbose:
            # The value is part of the message; the logger is not given extra positional arguments.
            logger.info(f"top_100 DEGs mean: {r_value_diff**2}")

    x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
    y = np.asarray(np.mean(stim.X, axis=0)).ravel()
    r_value = stats.linregress(x, y)[2]
    if verbose:
        logger.info(f"All genes mean: {r_value**2}")

    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs.get("range")
        ax.set_xticks(np.arange(start, stop, step))
        ax.set_yticks(np.arange(start, stop, step))
    ax.set_xlabel(labels["x"], fontsize=fontsize)
    ax.set_ylabel(labels["y"], fontsize=fontsize)

    if gene_list is not None:
        # Built once instead of once per gene.
        var_names = adata.var_names.tolist()
        texts = []
        for gene in gene_list:
            j = var_names.index(gene)
            x_bar = x[j]
            y_bar = y[j]
            texts.append(plt.text(x_bar, y_bar, gene, fontsize=11, color="black"))
            plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
        adjust_text(
            texts,
            x=x,
            y=y,
            arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
            force_static=(0.0, 0.0),
        )

    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.title("" if title is None else title, fontsize=fontsize)

    text_x = max(x) - max(x) * x_coeff
    textsize = kwargs.get("textsize", fontsize)
    ax.text(
        text_x,
        max(y) - y_coeff * max(y),
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
        fontsize=textsize,
    )
    if diff_genes is not None:
        ax.text(
            text_x,
            max(y) - (y_coeff + 0.15) * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
            fontsize=textsize,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

    if diff_genes is not None:
        return r_value**2, r_value_diff**2
    return r_value**2
|
510
|
+
|
511
|
+
def plot_reg_var_plot(
    self,
    adata,
    condition_key: str,
    axis_keys: dict[str, str],
    labels: dict[str, str],
    save: str | bool | None = None,
    gene_list: list[str] | None = None,
    top_100_genes: list[str] | None = None,
    show: bool = False,
    legend: bool = True,
    title: str | None = None,
    verbose: bool = False,
    x_coeff: float = 0.3,
    y_coeff: float = 0.8,
    fontsize: float = 14,
    **kwargs,
) -> tuple[float, float] | float:
    """Plots variance matching for a set of specified genes.

    The per-gene variance of the cells of condition `axis_keys["x"]` is regressed against the
    per-gene variance of the cells of condition `axis_keys["y"]`.

    Args:
        adata: AnnData object containing the cells of the conditions that are compared.
            `sc.tl.rank_genes_groups` is run on this object.
        condition_key: Key of the condition.
        axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
            `{"x": "Key for x-axis", "y": "Key for y-axis"}`. An optional `"y1"` entry adds the per-gene
            variance of a further condition as a grey scatter.
        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
        save: Location that is handed to `plt.savefig`. The plot is not saved if this is `None` or `False`.
        gene_list: list of gene names to be plotted.
        top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
        show: if `True`: will show to the plot after saving it.
        legend: Whether to plot a legend.
        title: Set if you want the plot to display a title.
        verbose: Specify if you want information to be printed while creating the plot.
        x_coeff: Offset to print the R^2 value in x-direction.
        y_coeff: Offset to print the R^2 value in y-direction.
        fontsize: Fontsize used for text in the plot.
        **kwargs: `range` (a `(start, stop, step)` triple used for the ticks of both axes) and
            `textsize` (font size of the R^2 annotations, defaults to `fontsize`) are recognized.

    Returns:
        The R^2 value of all genes. If `top_100_genes` is provided, a tuple of the R^2 value of all genes
        and the R^2 value of the provided genes.
    """
    import seaborn as sns

    sns.set_theme(color_codes=True)

    # NOTE(review): the ranking result is not read in this method; the call is kept because it
    # stores its result on `adata`, which callers may rely on.
    sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")

    diff_genes = top_100_genes
    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

    r_value_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        adata_diff = adata[:, diff_genes]
        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
        x_diff = np.asarray(np.var(ctrl_diff.X, axis=0)).ravel()
        y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
        r_value_diff = stats.linregress(x_diff, y_diff)[2]
        if verbose:
            # The value is part of the message; the logger is not given extra positional arguments.
            logger.info(f"Top 100 DEGs var: {r_value_diff**2}")

    has_y1 = "y1" in axis_keys
    if has_y1:
        real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]

    x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
    y = np.asarray(np.var(stim.X, axis=0)).ravel()
    r_value = stats.linregress(x, y)[2]
    if verbose:
        logger.info(f"All genes var: {r_value**2}")

    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs.get("range")
        ax.set_xticks(np.arange(start, stop, step))
        ax.set_yticks(np.arange(start, stop, step))
    ax.set_xlabel(labels["x"], fontsize=fontsize)
    ax.set_ylabel(labels["y"], fontsize=fontsize)

    if has_y1:
        y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
        plt.scatter(
            x,
            y1,
            marker="*",
            c="grey",
            alpha=0.5,
            label=f"{axis_keys['x']}-{axis_keys['y1']}",
        )

    if gene_list is not None:
        # Built once instead of once per gene.
        var_names = adata.var_names.tolist()
        for gene in gene_list:
            j = var_names.index(gene)
            x_bar = x[j]
            y_bar = y[j]
            plt.text(x_bar, y_bar, gene, fontsize=11, color="black")
            plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
            if has_y1:
                y1_bar = y1[j]
                plt.text(x_bar, y1_bar, "*", color="black", alpha=0.5)

    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.title("" if title is None else title, fontsize=12)

    text_x = max(x) - max(x) * x_coeff
    textsize = kwargs.get("textsize", fontsize)
    ax.text(
        text_x,
        max(y) - y_coeff * max(y),
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
        fontsize=textsize,
    )
    if diff_genes is not None:
        ax.text(
            text_x,
            max(y) - (y_coeff + 0.15) * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
            fontsize=textsize,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

    if diff_genes is not None:
        return r_value**2, r_value_diff**2
    return r_value**2
|
638
|
+
|
639
|
+
def plot_binary_classifier(
    self,
    scgen: Scgen,
    adata: AnnData | None,
    delta: np.ndarray,
    ctrl_key: str,
    stim_key: str,
    show: bool = False,
    save: str | bool | None = None,
    fontsize: float = 14,
) -> plt.Axes | None:
    """Plots the dot product between delta and latent representation of a linear classifier.

    Builds a linear classifier based on the dot product between
    the difference vector and the latent representation of each
    cell and plots the dot product results between delta and latent representation.

    Args:
        scgen: ScGen object that was trained.
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
            corresponding to batch and cell type metadata, respectively.
        delta: Difference between stimulated and control cells in latent space
        ctrl_key: Key for `control` part of the `data` found in `condition_key`.
        stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
        show: If `True`, the plot is shown.
        save: Location that is handed to `plt.savefig`. The plot is not saved if this is `None` or `False`.
        fontsize: Set the font size of the plot.

    Returns:
        The axes of the plot if the plot is neither shown nor saved, otherwise `None`.
    """
    plt.close("all")
    adata = scgen._validate_anndata(adata)
    condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
    cd = adata[adata.obs[condition_key] == ctrl_key, :]
    stim = adata[adata.obs[condition_key] == stim_key, :]

    # The AnnData subsets are passed, not their `.X` matrices: the latent lookup validates an AnnData.
    all_latent_cd = scgen.get_latent_representation(cd)
    all_latent_stim = scgen.get_latent_representation(stim)

    # One dot product with `delta` per cell.
    dot_cd = np.dot(all_latent_cd, delta)
    dot_sal = np.dot(all_latent_stim, delta)

    plt.hist(
        dot_cd,
        label=ctrl_key,
        bins=50,
    )
    plt.hist(dot_sal, label=stim_key, bins=50)
    plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
    plt.title("  ", fontsize=fontsize)
    plt.xlabel("  ", fontsize=fontsize)
    plt.ylabel("  ", fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    ax = plt.gca()
    ax.grid(False)

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if not (show or save):
        return ax
    return None
|
703
|
+
|
704
|
+
|
705
|
+
# compatibility
|
706
|
+
SCGEN = Scgen
|