pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,701 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any
|
4
|
+
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import matplotlib.pyplot as plt
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import scanpy as sc
|
10
|
+
from adjustText import adjust_text
|
11
|
+
from anndata import AnnData
|
12
|
+
from jax import Array
|
13
|
+
from scipy import stats
|
14
|
+
from scvi import REGISTRY_KEYS
|
15
|
+
from scvi.data import AnnDataManager
|
16
|
+
from scvi.data.fields import CategoricalObsField, LayerField
|
17
|
+
from scvi.model.base import BaseModelClass, JaxTrainingMixin
|
18
|
+
from scvi.utils import setup_anndata_dsp
|
19
|
+
|
20
|
+
from ._scgenvae import JaxSCGENVAE
|
21
|
+
from ._utils import balancer, extractor
|
22
|
+
|
23
|
+
if TYPE_CHECKING:
|
24
|
+
from collections.abc import Sequence
|
25
|
+
|
26
|
+
# Arial / 14pt font specification.
# NOTE(review): nothing in the visible part of this module references `font` — confirm whether
# it is still needed (e.g. imported elsewhere) before removing it.
font = {"family": "Arial", "size": 14}
|
27
|
+
|
28
|
+
|
29
|
+
class SCGEN(JaxTrainingMixin, BaseModelClass):
|
30
|
+
"""Jax Implementation of scGen model for batch removal and perturbation prediction."""
|
31
|
+
|
32
|
+
def __init__(
    self,
    adata: AnnData,
    n_hidden: int = 800,
    n_latent: int = 100,
    n_layers: int = 2,
    dropout_rate: float = 0.2,
    **model_kwargs,
):
    """Build the scGen model around an AnnData that was prepared with `setup_anndata`.

    Args:
        adata: AnnData object handed to the base model class.
        n_hidden: Forwarded to `JaxSCGENVAE` as `n_hidden`.
        n_latent: Forwarded to `JaxSCGENVAE` as `n_latent`.
        n_layers: Forwarded to `JaxSCGENVAE` as `n_layers`.
        dropout_rate: Forwarded to `JaxSCGENVAE` as `dropout_rate`.
        **model_kwargs: Any further keyword arguments for `JaxSCGENVAE`.
    """
    super().__init__(adata)

    # The input dimension is the number of variables recorded in the summary statistics.
    self.module = JaxSCGENVAE(
        n_input=self.summary_stats.n_vars,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        **model_kwargs,
    )
    self._model_summary_string = (
        "SCGEN Model with the following params: \n"
        f"n_hidden: {n_hidden}, n_latent: {n_latent}, "
        f"n_layers: {n_layers}, dropout_rate: {dropout_rate}"
    )
    # `locals()` is captured as-is, so no helper variables are introduced in this method.
    self.init_params_ = self._get_init_params(locals())
|
56
|
+
|
57
|
+
def predict(
    self,
    ctrl_key=None,
    stim_key=None,
    adata_to_predict=None,
    celltype_to_predict=None,
    restrict_arithmetic_to="all",
) -> tuple[AnnData, Any]:
    """Predicts the cell type provided by the user in stimulated condition.

    The perturbation is modelled as one latent-space shift `delta`: the mean latent vector of
    the stimulated cells minus the mean latent vector of the control cells. `delta` is added to
    the latent representation of the cells to predict and the result is decoded.

    Args:
        ctrl_key: Value marking `control` cells in the condition column (the `batch_key`
            registered through `setup_anndata`).
        stim_key: Value marking `stimulated` cells in the condition column.
        adata_to_predict: Adata for unperturbed cells you want to be predicted.
            Mutually exclusive with `celltype_to_predict`.
        celltype_to_predict: The cell type you want to be predicted. Its cells are taken from the
            model's AnnData via `extractor`. Mutually exclusive with `adata_to_predict`.
        restrict_arithmetic_to: `"all"` to estimate `delta` from every cell of the model's AnnData,
            or a dictionary `{obs_column: values}` restricting the estimate to cells whose
            `obs_column` is in `values`. Only the first key of the dictionary is used.

    Returns:
        A tuple `(predicted_adata, delta)`: an AnnData holding the decoded predicted cells (with
        `obs`, `var` and `obsm` copied from the cells that were predicted) and the latent
        difference vector between stimulated and control cells.

    Raises:
        ValueError: If both or neither of `celltype_to_predict` and `adata_to_predict` are given,
            or if no control / stimulated cells are found for the given keys.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
    """
    # Validate the arguments first so that a wrong call fails before any subsetting/balancing.
    if celltype_to_predict is not None and adata_to_predict is not None:
        raise ValueError("Please provide either a cell type or adata not both!")
    if celltype_to_predict is None and adata_to_predict is None:
        raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

    # use keys registered from `setup_anndata()`
    cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
    condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

    if restrict_arithmetic_to == "all":
        source = self.adata
        rebalance = True
    else:
        key = next(iter(restrict_arithmetic_to))
        values = restrict_arithmetic_to[key]
        source = self.adata[self.adata.obs[key].isin(values)]
        # With a single selected value there is nothing to balance across.
        rebalance = len(values) > 1

    ctrl_x = source[source.obs[condition_key] == ctrl_key, :]
    stim_x = source[source.obs[condition_key] == stim_key, :]
    if ctrl_x.shape[0] == 0 or stim_x.shape[0] == 0:
        raise ValueError(
            f"No cells found for ctrl_key={ctrl_key!r} and/or stim_key={stim_key!r} "
            f"in `adata.obs[{condition_key!r}]`."
        )
    if rebalance:
        # NOTE(review): `balancer` lives in `._utils`; presumably it equalises the number of
        # cells per cell type — confirm against its implementation.
        ctrl_x = balancer(ctrl_x, cell_type_key)
        stim_x = balancer(stim_x, cell_type_key)

    if celltype_to_predict is not None:
        # The second element returned by `extractor` is used as the cells to predict.
        ctrl_pred = extractor(
            self.adata,
            celltype_to_predict,
            condition_key,
            cell_type_key,
            ctrl_key,
            stim_key,
        )[1]
    else:
        ctrl_pred = adata_to_predict

    # Draw equally many control and stimulated cells (without replacement) for the delta estimate.
    eq = min(ctrl_x.shape[0], stim_x.shape[0])
    rng = np.random.default_rng()
    cd_ind = rng.choice(ctrl_x.shape[0], size=eq, replace=False)
    stim_ind = rng.choice(stim_x.shape[0], size=eq, replace=False)
    ctrl_adata = ctrl_x[cd_ind, :]
    stim_adata = stim_x[stim_ind, :]

    delta = self._avg_vector(stim_adata) - self._avg_vector(ctrl_adata)

    # Shift the latent representation of the cells to predict and decode it.
    latent_cd = self.get_latent_representation(ctrl_pred)
    stim_pred = delta + latent_cd
    predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

    predicted_adata = AnnData(
        X=np.array(predicted_cells),
        obs=ctrl_pred.obs.copy(),
        var=ctrl_pred.var.copy(),
        obsm=ctrl_pred.obsm.copy(),
    )
    return predicted_adata, delta
|
145
|
+
|
146
|
+
def _avg_vector(self, adata):
    """Return the latent representation of `adata` averaged over axis 0."""
    latent = self.get_latent_representation(adata)
    return np.mean(latent, axis=0)
|
148
|
+
|
149
|
+
def get_decoded_expression(
    self,
    adata: AnnData | None = None,
    indices: Sequence[int] | None = None,
    batch_size: int | None = None,
) -> Array:
    """Run the full module on every minibatch and return the decoded expression (`px`).

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices: Indices of cells in adata to use. If `None`, all cells are used.
        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns:
        The `px` generative output of all minibatches, concatenated into one array.

    Raises:
        RuntimeError: If the model has not been trained (`is_trained_` is `False`).

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> decoded_X = model.get_decoded_expression()
    """
    if self.is_trained_ is False:
        raise RuntimeError("Please train the model first.")

    validated = self._validate_anndata(adata)
    loader = self._make_data_loader(adata=validated, indices=indices, batch_size=batch_size)
    # The module call yields (inference outputs, generative outputs); only `px` is kept.
    per_batch = [self.module.as_bound()(batch, compute_loss=False)[1]["px"] for batch in loader]
    return jnp.concatenate(per_batch)
|
186
|
+
|
187
|
+
def batch_removal(self, adata: AnnData | None = None) -> AnnData:
    """Removes batch effects.

    For every cell type that occurs in at least two batches, the latent vectors of each batch are
    shifted so that the batch mean coincides with the mean of the largest batch of that cell type.
    Cell types seen in a single batch are left unchanged. The (shifted) latent vectors are then
    decoded back to expression space.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
            corresponding to batch and cell type metadata, respectively. `obs_names` must be unique
            because they are used to restore the original cell order.

    Returns:
        corrected: `~anndata.AnnData`
        AnnData of corrected gene expression in `adata.X`, in the same cell order as the input.
        `obsm["latent"]` holds the batch-shifted latent vectors that were decoded and
        `obsm["corrected_latent"]` the latent representation of the corrected expression.
        If the input adata has a `raw` attribute, its `X` and `var` are carried over to `corrected.raw`.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> corrected_adata = model.batch_removal()
    """
    adata = self._validate_anndata(adata)
    latent_all = self.get_latent_representation(adata)
    # use keys registered from `setup_anndata()`
    cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
    batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

    adata_latent = AnnData(latent_all)
    adata_latent.obs = adata.obs.copy(deep=True)

    shared_ct = []
    not_shared_ct = []
    for cell_type in np.unique(adata_latent.obs[cell_label_key]):
        temp_cell = adata_latent[adata_latent.obs[cell_label_key] == cell_type].copy()
        batch_labels = temp_cell.obs[batch_key].to_numpy()
        batches = np.unique(batch_labels)
        if len(batches) < 2:
            # Cell type seen in one batch only: nothing to align against.
            not_shared_ct.append(temp_cell)
            continue

        masks = {batch: batch_labels == batch for batch in batches}
        # The largest batch is the reference; `max` keeps the first one on ties.
        reference = max(batches, key=lambda batch: masks[batch].sum())

        latent = temp_cell.X
        reference_mean = np.average(latent[masks[reference]], axis=0)
        shifted = np.array(latent, copy=True)
        for mask in masks.values():
            # Move the batch mean onto the reference mean (a zero shift for the reference itself).
            shifted[mask] = latent[mask] + (reference_mean - np.average(latent[mask], axis=0))
        temp_cell.X = shifted
        shared_ct.append(temp_cell)

    # Corrected (shared) cell types first, then the untouched ones.
    all_corrected_data = AnnData.concatenate(
        *shared_ct,
        *not_shared_ct,
        batch_key="concat_batch",
        index_unique=None,
    )
    # `concat_batch` is only a by-product of the concatenation and must not reach the result.
    if "concat_batch" in all_corrected_data.obs.columns:
        del all_corrected_data.obs["concat_batch"]
    # Restore the input cell order *before* decoding so that X, obs and both obsm entries
    # of the result refer to the same cell in every row.
    all_corrected_data = all_corrected_data[adata.obs_names].copy()

    corrected = AnnData(
        np.array(self.module.as_bound().generative(all_corrected_data.X)["px"]),
        obs=all_corrected_data.obs,
    )
    corrected.var_names = adata.var_names.tolist()
    if adata.raw is not None:
        adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
        adata_raw.obs_names = adata.obs_names
        corrected.raw = adata_raw
    corrected.obsm["latent"] = all_corrected_data.X
    corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)

    return corrected
|
287
|
+
|
288
|
+
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
    cls,
    adata: AnnData,
    batch_key: str | None = None,
    labels_key: str | None = None,
    **kwargs,
):
    """%(summary)s.

    scGen reads the expression values from `adata.X`; the matrix is registered as non-count data.

    %(param_batch_key)s
    %(param_labels_key)s

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
    """
    # Keep this as the first statement: `locals()` must contain the call arguments only.
    setup_method_args = cls._get_setup_method_args(**locals())
    manager = AnnDataManager(
        fields=[
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
|
318
|
+
|
319
|
+
def to_device(self, device):
    """Do nothing and return `None`; the `device` argument is ignored.

    NOTE(review): presumably present to satisfy the base-class interface — confirm.
    """
    return None
|
321
|
+
|
322
|
+
@property
def device(self):
    """Device reported by the wrapped module (`self.module.device`)."""
    module = self.module
    return module.device
|
325
|
+
|
326
|
+
def get_latent_representation(
    self,
    adata: AnnData | None = None,
    indices: Sequence[int] | None = None,
    give_mean: bool = True,
    n_samples: int = 1,
    batch_size: int | None = None,
) -> np.ndarray:
    """Return the latent representation for each cell.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices: Indices of cells in adata to use. If `None`, all cells are used.
        give_mean: If `True`, the mean of the `qz` inference output is returned; otherwise the
            `z` inference output is returned.
        n_samples: Passed to the module's inference function as `n_samples`.
        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns:
        Low-dimensional representation for each cell as a NumPy array. Minibatches are joined along
        axis 0, except when `n_samples != 1` and `give_mean` is `False`, where axis 1 is used.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> latent_X = model.get_latent_representation()
    """
    self._check_if_trained(warn=False)

    validated = self._validate_anndata(adata)
    loader = self._make_data_loader(
        adata=validated,
        indices=indices,
        batch_size=batch_size,
        iter_ndarray=True,
    )
    infer = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

    chunks = []
    for batch in loader:
        # `self.module.rngs` is read once per minibatch, exactly as the inference call needs it.
        outputs = infer(self.module.rngs, batch)
        chunks.append(outputs["qz"].mean if give_mean else outputs["z"])

    axis = 1 if (n_samples != 1 and not give_mean) else 0
    stacked = jnp.concatenate(chunks, axis=axis)  # type: ignore

    return self.module.as_numpy_array(stacked)
|
372
|
+
|
373
|
+
def plot_reg_mean_plot(
    self,
    adata,
    condition_key: str,
    axis_keys: dict[str, str],
    labels: dict[str, str],
    save: str | bool | None = None,
    gene_list: list[str] | None = None,
    show: bool = False,
    top_100_genes: list[str] | None = None,
    verbose: bool = False,
    legend: bool = True,
    title: str | None = None,
    x_coeff: float = 0.30,
    y_coeff: float = 0.8,
    fontsize: float = 14,
    **kwargs,
) -> tuple[float, float] | float:
    """Plots mean matching for a set of specified genes.

    The per-gene mean expression of the cells labelled `axis_keys["x"]` is regressed against the
    per-gene mean expression of the cells labelled `axis_keys["y"]` and drawn as a regression plot
    annotated with the squared correlation coefficient.

    Args:
        adata: AnnData object containing the cells of both conditions to compare.
        condition_key: The key for the condition
        axis_keys: Dictionary of values of `adata.obs[condition_key]` that are used by the axes of the plot.
            Has to be in the following form: `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
        save: Destination passed to `matplotlib.pyplot.savefig`. Nothing is saved if falsy.
        gene_list: list of gene names to be highlighted and annotated in the plot.
        show: if `True`: will show to the plot after saving it.
        top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
        verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
        legend: if `True`: plots a legend, defaults to `True`.
        title: Set if you want the plot to display a title.
        x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
        y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
        fontsize: Fontsize used for text in the plot, defaults to 14.
        **kwargs: Optional extras. `range=(start, stop, step)` sets the ticks of both axes,
            `textsize` overrides the font size of the R^2 annotations.

    Returns:
        The squared correlation coefficient over all genes. If `top_100_genes` is given, a tuple
        `(r2_all_genes, r2_top_100_genes)` is returned instead.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> scg = pt.tl.SCGEN(data)
        >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
        >>> pred.obs['label'] = 'pred'
        >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
        >>> r2_value = scg.plot_reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \
            labels={"x": "predicted", "y": "ground truth"}, save=False, show=True)

    Preview:
        .. image:: /_static/docstring_previews/scgen_reg_mean.png
    """
    import seaborn as sns

    sns.set_theme(color_codes=True)

    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

    diff_genes = top_100_genes
    r_value_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        # Same regression, restricted to the supplied differentially expressed genes.
        x_diff = np.asarray(np.mean(ctrl[:, diff_genes].X, axis=0)).ravel()
        y_diff = np.asarray(np.mean(stim[:, diff_genes].X, axis=0)).ravel()
        r_value_diff = stats.linregress(x_diff, y_diff).rvalue
        if verbose:
            print("top_100 DEGs mean: ", r_value_diff**2)

    # `np.asarray(...).ravel()` flattens the matrix-shaped result obtained for sparse `X`.
    x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
    y = np.asarray(np.mean(stim.X, axis=0)).ravel()
    r_value = stats.linregress(x, y).rvalue
    if verbose:
        print("All genes mean: ", r_value**2)

    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs["range"]
        ticks = np.arange(start, stop, step)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
    ax.set_xlabel(labels["x"], fontsize=fontsize)
    ax.set_ylabel(labels["y"], fontsize=fontsize)

    if gene_list is not None:
        # Build the name list once instead of once per gene.
        var_names = adata.var_names.tolist()
        texts = []
        for gene in gene_list:
            idx = var_names.index(gene)
            texts.append(plt.text(x[idx], y[idx], gene, fontsize=11, color="black"))
            plt.plot(x[idx], y[idx], "o", color="red", markersize=5)
        adjust_text(
            texts,
            x=x,
            y=y,
            arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
            force_static=(0.0, 0.0),
        )

    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.title("" if title is None else title, fontsize=fontsize)

    # The R^2 annotations are positioned relative to the largest x / y value.
    x_max = x.max()
    y_max = y.max()
    text_size = kwargs.get("textsize", fontsize)
    ax.text(
        x_max - x_max * x_coeff,
        y_max - y_coeff * y_max,
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
        fontsize=text_size,
    )
    if diff_genes is not None:
        ax.text(
            x_max - x_max * x_coeff,
            y_max - (y_coeff + 0.15) * y_max,
            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
            fontsize=text_size,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

    if diff_genes is not None:
        return r_value**2, r_value_diff**2
    return r_value**2
|
509
|
+
|
510
|
+
def plot_reg_var_plot(
    self,
    adata,
    condition_key: str,
    axis_keys: dict[str, str],
    labels: dict[str, str],
    save: str | bool | None = None,
    gene_list: list[str] | None = None,
    top_100_genes: list[str] | None = None,
    show: bool = False,
    legend: bool = True,
    title: str | None = None,
    verbose: bool = False,
    x_coeff: float = 0.3,
    y_coeff: float = 0.8,
    fontsize: float = 14,
    **kwargs,
) -> tuple[float, float] | float:
    """Plots variance matching for a set of specified genes.

    The per-gene expression variance of the cells labelled `axis_keys["x"]` is regressed against
    the per-gene expression variance of the cells labelled `axis_keys["y"]` and drawn as a
    regression plot annotated with the squared correlation coefficient.

    Args:
        adata: AnnData object containing the cells of the conditions to compare.
            `scanpy.tl.rank_genes_groups` is run on it, so it is modified in place.
        condition_key: Key of the condition.
        axis_keys: Dictionary of values of `adata.obs[condition_key]` that are used by the axes of the plot.
            Has to be in the following form: `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            An optional `"y1"` entry adds a grey scatter of a further condition against the x-axis.
        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
        save: Destination passed to `matplotlib.pyplot.savefig`. Nothing is saved if falsy.
        gene_list: list of gene names to be highlighted and annotated in the plot.
        top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
        show: if `True`: will show to the plot after saving it.
        legend: if `True`: plots a legend, defaults to `True`.
        title: Set if you want the plot to display a title.
        verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
        x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
        y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
        fontsize: Fontsize used for tick labels, axis labels and the R^2 annotations, defaults to 14.
        **kwargs: Optional extras. `range=(start, stop, step)` sets the ticks of both axes,
            `textsize` overrides the font size of the R^2 annotations.

    Returns:
        The squared correlation coefficient over all genes. If `top_100_genes` is given, a tuple
        `(r2_all_genes, r2_top_100_genes)` is returned instead.
    """
    import seaborn as sns

    sns.set_theme(color_codes=True)

    # NOTE(review): the ranking result is never read in this method (the DEGs come from
    # `top_100_genes`); the call is kept only because it writes to the caller's `adata` and
    # removing it would change that side effect — confirm whether it can be dropped.
    sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")

    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

    diff_genes = top_100_genes
    r_value_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        # Same regression, restricted to the supplied differentially expressed genes.
        x_diff = np.asarray(np.var(ctrl[:, diff_genes].X, axis=0)).ravel()
        y_diff = np.asarray(np.var(stim[:, diff_genes].X, axis=0)).ravel()
        r_value_diff = stats.linregress(x_diff, y_diff).rvalue
        if verbose:
            print("Top 100 DEGs var: ", r_value_diff**2)

    x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
    y = np.asarray(np.var(stim.X, axis=0)).ravel()
    r_value = stats.linregress(x, y).rvalue
    if verbose:
        print("All genes var: ", r_value**2)

    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs["range"]
        ticks = np.arange(start, stop, step)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
    ax.set_xlabel(labels["x"], fontsize=fontsize)
    ax.set_ylabel(labels["y"], fontsize=fontsize)

    # Optional third condition, drawn as grey stars against the x-axis condition.
    has_y1 = "y1" in axis_keys
    y1 = None
    if has_y1:
        real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
        y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
        plt.scatter(
            x,
            y1,
            marker="*",
            c="grey",
            alpha=0.5,
            label=f"{axis_keys['x']}-{axis_keys['y1']}",
        )

    if gene_list is not None:
        # Build the name list once instead of once per gene.
        var_names = adata.var_names.tolist()
        for gene in gene_list:
            idx = var_names.index(gene)
            plt.text(x[idx], y[idx], gene, fontsize=11, color="black")
            plt.plot(x[idx], y[idx], "o", color="red", markersize=5)
            if has_y1:
                plt.text(x[idx], y1[idx], "*", color="black", alpha=0.5)

    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    # The title size is fixed at 12 and does not follow `fontsize`.
    plt.title("" if title is None else title, fontsize=12)

    # The R^2 annotations are positioned relative to the largest x / y value.
    x_max = x.max()
    y_max = y.max()
    text_size = kwargs.get("textsize", fontsize)
    ax.text(
        x_max - x_max * x_coeff,
        y_max - y_coeff * y_max,
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
        fontsize=text_size,
    )
    if diff_genes is not None:
        ax.text(
            x_max - x_max * x_coeff,
            y_max - (y_coeff + 0.15) * y_max,
            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
            fontsize=text_size,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

    if diff_genes is not None:
        return r_value**2, r_value_diff**2
    return r_value**2
|
637
|
+
|
638
|
+
def plot_binary_classifier(
    self,
    scgen: SCGEN,
    adata: AnnData | None,
    delta: np.ndarray,
    ctrl_key: str,
    stim_key: str,
    show: bool = False,
    save: str | bool | None = None,
    fontsize: float = 14,
) -> plt.Axes | None:
    """Plots the dot product between delta and latent representation of a linear classifier.

    Builds a linear classifier based on the dot product between
    the difference vector and the latent representation of each
    cell and plots the dot product results between delta and latent representation
    as one histogram per condition.

    Args:
        scgen: ScGen object that was trained.
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
            corresponding to batch and cell type metadata, respectively.
        delta: Difference between stimulated and control cells in latent space
        ctrl_key: Key for `control` part of the `data` found in `condition_key`.
        stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
        show: If `True`, the figure is displayed with `matplotlib.pyplot.show`.
        save: Destination passed to `matplotlib.pyplot.savefig`. Nothing is saved if falsy.
        fontsize: Set the font size of the plot.

    Returns:
        The matplotlib axes if the plot is neither shown nor saved, otherwise `None`.
    """
    plt.close("all")
    adata = scgen._validate_anndata(adata)
    condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
    cd = adata[adata.obs[condition_key] == ctrl_key, :]
    stim = adata[adata.obs[condition_key] == stim_key, :]

    # `get_latent_representation` validates its argument as an AnnData, so the AnnData subsets
    # are passed rather than their raw `.X` matrices.
    all_latent_cd = scgen.get_latent_representation(cd)
    all_latent_stim = scgen.get_latent_representation(stim)

    # One projection onto `delta` per cell, computed as a single matrix-vector product.
    direction = np.asarray(delta)
    dot_cd = np.asarray(np.dot(all_latent_cd, direction), dtype=float)
    dot_sal = np.asarray(np.dot(all_latent_stim, direction), dtype=float)

    plt.hist(dot_cd, label=ctrl_key, bins=50)
    plt.hist(dot_sal, label=stim_key, bins=50)
    # Dashed vertical line at a projection value of zero.
    plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
    plt.title("  ", fontsize=fontsize)
    plt.xlabel("  ", fontsize=fontsize)
    plt.ylabel("  ", fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    ax = plt.gca()
    ax.grid(False)

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if not (show or save):
        return ax
    return None
|