pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,701 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import jax.numpy as jnp
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import scanpy as sc
10
+ from adjustText import adjust_text
11
+ from anndata import AnnData
12
+ from jax import Array
13
+ from scipy import stats
14
+ from scvi import REGISTRY_KEYS
15
+ from scvi.data import AnnDataManager
16
+ from scvi.data.fields import CategoricalObsField, LayerField
17
+ from scvi.model.base import BaseModelClass, JaxTrainingMixin
18
+ from scvi.utils import setup_anndata_dsp
19
+
20
+ from ._scgenvae import JaxSCGENVAE
21
+ from ._utils import balancer, extractor
22
+
23
+ if TYPE_CHECKING:
24
+ from collections.abc import Sequence
25
+
26
# Module-level font settings (Arial, 14 pt).
# NOTE(review): not referenced anywhere in this module; confirm external use before removing.
font = dict(family="Arial", size=14)
27
+
28
+
29
class SCGEN(JaxTrainingMixin, BaseModelClass):
    """Jax Implementation of scGen model for batch removal and perturbation prediction.

    The model wraps a `JaxSCGENVAE` module. Expression is read from `adata.X` (see `setup_anndata`).
    `predict` and `batch_removal` use the registered batch key as the condition column and the
    registered labels key as the cell type column.
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 800,
        n_latent: int = 100,
        n_layers: int = 2,
        dropout_rate: float = 0.2,
        **model_kwargs,
    ):
        """Create the model and its `JaxSCGENVAE` module.

        Args:
            adata: AnnData object registered via `SCGEN.setup_anndata`.
            n_hidden: Hidden size forwarded to `JaxSCGENVAE`.
            n_latent: Latent dimensionality forwarded to `JaxSCGENVAE`.
            n_layers: Number of layers forwarded to `JaxSCGENVAE`.
            dropout_rate: Dropout rate forwarded to `JaxSCGENVAE`.
            **model_kwargs: Additional keyword arguments forwarded to `JaxSCGENVAE`.
        """
        super().__init__(adata)

        self.module = JaxSCGENVAE(
            n_input=self.summary_stats.n_vars,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}"
        )
        # `locals()` is captured here, so do not introduce additional local variables above this line.
        self.init_params_ = self._get_init_params(locals())

    def predict(
        self,
        ctrl_key=None,
        stim_key=None,
        adata_to_predict=None,
        celltype_to_predict=None,
        restrict_arithmetic_to="all",
    ) -> tuple[AnnData, Any]:
        """Predicts the cell type provided by the user in stimulated condition.

        The perturbation is modelled as a shift `delta` in latent space. `delta` is the mean latent
        vector of stimulated cells minus the mean latent vector of control cells. Both groups are
        first passed through `balancer` (when applicable) and then randomly subsampled to the same
        size. `delta` is added to the latent representation of the cells to predict, and the
        result is decoded back into expression space.

        Args:
            ctrl_key: Value of the registered condition (batch) column that marks control cells.
            stim_key: Value of the registered condition (batch) column that marks stimulated cells.
            adata_to_predict: AnnData of unperturbed cells to predict. Mutually exclusive with `celltype_to_predict`.
            celltype_to_predict: Cell type (value of the registered labels column) whose control cells are predicted.
                Mutually exclusive with `adata_to_predict`.
            restrict_arithmetic_to: Either `"all"`, to compute `delta` from all cells, or a dictionary
                `{obs_key: values}`. Only its first entry is used; it restricts the cells that contribute to `delta`.

        Returns:
            A tuple `(predicted_adata, delta)`. `predicted_adata` holds the predicted cells in expression space,
            with `obs`, `var` and `obsm` copied from the predicted control cells. `delta` is the difference between
            the mean stimulated and the mean control latent vectors.

        Raises:
            ValueError: If both or neither of `celltype_to_predict` and `adata_to_predict` are given, or if no
                control or stimulated cells are available to compute `delta`.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        """
        # Validate the mutually exclusive inputs before doing any expensive subsetting.
        if celltype_to_predict is not None and adata_to_predict is not None:
            raise ValueError("Please provide either a cell type or adata not both!")
        if celltype_to_predict is None and adata_to_predict is None:
            raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

        # use keys registered from `setup_anndata()`
        cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        if restrict_arithmetic_to == "all":
            pool = self.adata
            balance = True
        else:
            key = next(iter(restrict_arithmetic_to))
            values = restrict_arithmetic_to[key]
            pool = self.adata[self.adata.obs[key].isin(values)]
            # Balancing only makes sense when more than one group was selected.
            balance = len(values) > 1

        ctrl_x = pool[pool.obs[condition_key] == ctrl_key, :]
        stim_x = pool[pool.obs[condition_key] == stim_key, :]
        if ctrl_x.n_obs == 0 or stim_x.n_obs == 0:
            raise ValueError(
                f"No cells found for ctrl_key={ctrl_key!r} and/or stim_key={stim_key!r} in column '{condition_key}'."
            )
        if balance:
            ctrl_x = balancer(ctrl_x, cell_type_key)
            stim_x = balancer(stim_x, cell_type_key)

        if celltype_to_predict is not None:
            ctrl_pred = extractor(
                self.adata,
                celltype_to_predict,
                condition_key,
                cell_type_key,
                ctrl_key,
                stim_key,
            )[1]
        else:
            ctrl_pred = adata_to_predict

        # Subsample both groups to the same size so neither dominates the mean difference.
        n_pairs = min(ctrl_x.n_obs, stim_x.n_obs)
        rng = np.random.default_rng()
        ctrl_adata = ctrl_x[rng.choice(ctrl_x.n_obs, size=n_pairs, replace=False), :]
        stim_adata = stim_x[rng.choice(stim_x.n_obs, size=n_pairs, replace=False), :]

        latent_ctrl = self._avg_vector(ctrl_adata)
        latent_stim = self._avg_vector(stim_adata)
        delta = latent_stim - latent_ctrl

        latent_cd = self.get_latent_representation(ctrl_pred)
        stim_pred = delta + latent_cd
        predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

        predicted_adata = AnnData(
            X=np.array(predicted_cells),
            obs=ctrl_pred.obs.copy(),
            var=ctrl_pred.var.copy(),
            obsm=ctrl_pred.obsm.copy(),
        )
        return predicted_adata, delta

    def _avg_vector(self, adata):
        """Return the mean latent vector (mean over cells) of `adata`."""
        return np.mean(self.get_latent_representation(adata), axis=0)

    def get_decoded_expression(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> Array:
        """Get decoded expression.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Decoded expression (`px` of the generative output) for each cell, concatenated over minibatches.

        Raises:
            RuntimeError: If the model has not been trained.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
        if not self.is_trained_:
            raise RuntimeError("Please train the model first.")

        adata = self._validate_anndata(adata)
        # Yield numpy arrays, matching the JAX data path used by `get_latent_representation`.
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)
        decoded = []
        for tensors in scdl:
            _, generative_outputs = self.module.as_bound()(tensors, compute_loss=False)
            decoded.append(generative_outputs["px"])

        return jnp.concatenate(decoded)

    def batch_removal(self, adata: AnnData | None = None) -> AnnData:
        """Removes batch effects.

        Cells are encoded into latent space. The shift is applied per cell type, and only to cell
        types observed in at least two batches. Each batch's latent vectors are shifted so that the
        batch mean equals the mean of the largest batch of that cell type; ties go to the first batch
        in sorted order. Cell types seen in a single batch are left unchanged. All latent vectors are
        then decoded back into expression space.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.

        Returns:
            corrected: `~anndata.AnnData` in the same cell order as `adata`, containing:
            - the decoded (corrected) expression in `X`;
            - the shifted latent space in `obsm["latent"]`;
            - the latent representation of the decoded expression in `obsm["corrected_latent"]`.
            If the input has a `raw` attribute, its `X` and `var` are copied into `corrected.raw`.

        Raises:
            ValueError: If `adata` contains no cells.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
        adata = self._validate_anndata(adata)
        latent_all = self.get_latent_representation(adata)
        # use keys registered from `setup_anndata()`
        cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        adata_latent = AnnData(latent_all)
        adata_latent.obs = adata.obs.copy(deep=True)

        parts = []
        for cell_type in np.unique(adata_latent.obs[cell_label_key]):
            temp_cell = adata_latent[(adata_latent.obs[cell_label_key] == cell_type).to_numpy()].copy()
            batches = np.unique(temp_cell.obs[batch_key])
            if len(batches) >= 2:
                self._align_batches_to_largest(temp_cell, batch_key, batches)
            parts.append(temp_cell)

        if not parts:
            raise ValueError("Cannot remove batch effects from an AnnData object without cells.")

        merged = parts[0].concatenate(*parts[1:], batch_key="concat_batch", index_unique=None)
        if "concat_batch" in merged.obs.columns:
            del merged.obs["concat_batch"]
        # Restore the input cell order so X, obs and obsm stay aligned with `adata`.
        merged = merged[adata.obs_names].copy()

        corrected = AnnData(
            np.array(self.module.as_bound().generative(merged.X)["px"]),
            obs=merged.obs,
        )
        corrected.var_names = adata.var_names.tolist()
        if adata.raw is not None:
            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        corrected.obsm["latent"] = merged.X
        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)
        return corrected

    @staticmethod
    def _align_batches_to_largest(temp_cell: AnnData, batch_key: str, batches: np.ndarray) -> None:
        """Shift each batch of `temp_cell.X` in place so its mean equals the mean of the largest batch."""
        batch_labels = temp_cell.obs[batch_key].to_numpy()
        counts = [int(np.sum(batch_labels == batch)) for batch in batches]
        # argmax returns the first maximum, i.e. ties resolve to the first batch in sorted order.
        reference = batches[int(np.argmax(counts))]
        latent = temp_cell.X
        reference_mean = latent[batch_labels == reference].mean(axis=0)
        for batch in batches:
            mask = batch_labels == batch
            latent[mask] = latent[mask] + (reference_mean - latent[mask].mean(axis=0))
        temp_cell.X = latent

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        labels_key: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        scGen expects the expression data to come from `adata.X`

        %(param_batch_key)s
        %(param_labels_key)s

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def to_device(self, device):
        """No-op: this model does not move its module between devices here."""
        pass

    @property
    def device(self):
        """Device reported by the underlying module."""
        return self.module.device

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        n_samples: int = 1,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Return the latent representation for each cell.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            give_mean: If `True`, return the mean of the posterior `qz`; otherwise return the sampled `z`.
            n_samples: Number of samples passed to the inference function. When `n_samples > 1` and
                `give_mean` is `False`, minibatch results are concatenated along axis 1 instead of axis 0.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Low-dimensional representation for each cell

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
        self._check_if_trained(warn=False)

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)

        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

        latent = []
        for array_dict in scdl:
            out = jit_inference_fn(self.module.rngs, array_dict)
            latent.append(out["qz"].mean if give_mean else out["z"])
        concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
        latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore

        return self.module.as_numpy_array(latent)

    def plot_reg_mean_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        show: bool = False,
        top_100_genes: list[str] | None = None,
        verbose: bool = False,
        legend: bool = True,
        title: str | None = None,
        x_coeff: float = 0.30,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots mean matching for a set of specified genes.

        Computes the per-gene mean expression of the `axis_keys["x"]` and `axis_keys["y"]` groups and
        plots them against each other. A regression line is drawn and the R² of a linear fit is annotated.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.
            condition_key: The key for the condition
            axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: File path passed to `plt.savefig`; falsy values skip saving.
            gene_list: list of gene names to be plotted.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
            legend: if `True`: plots a legend, defaults to `True`.
            title: Set if you want the plot to display a title.
            x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
            y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
            fontsize: Fontsize used for text in the plot, defaults to 14.
            **kwargs: Optional `range=(start, stop, step)` for the ticks of both axes, and `textsize` for the
                R² annotations (defaults to `fontsize`).

        Returns:
            R² over all genes. If `top_100_genes` is given, a tuple `(R² all genes, R² top 100 DEGs)`.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> scg = pt.tl.SCGEN(data)
            >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
            >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
            >>> r2_value = scg.plot_reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \
                labels={"x": "predicted", "y": "ground truth"}, save=False, show=True)

        Preview:
            .. image:: /_static/docstring_previews/scgen_reg_mean.png
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
            y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
            r_value_diff = stats.linregress(x_diff, y_diff).rvalue
            if verbose:
                print("top_100 DEGs mean: ", r_value_diff**2)
        x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.mean(stim.X, axis=0)).ravel()
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            print("All genes mean: ", r_value**2)
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)
        if gene_list is not None:
            # Build the name list once instead of once per labelled gene.
            var_names = adata.var_names.tolist()
            texts = []
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                texts.append(plt.text(x_bar, y_bar, gene, fontsize=11, color="black"))
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
            adjust_text(
                texts,
                x=x,
                y=y,
                arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
                force_static=(0.0, 0.0),
            )
        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title("" if title is None else title, fontsize=fontsize)
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )
        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()
        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        return r_value**2

    def plot_reg_var_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        top_100_genes: list[str] | None = None,
        show: bool = False,
        legend: bool = True,
        title: str | None = None,
        verbose: bool = False,
        x_coeff: float = 0.3,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots variance matching for a set of specified genes.

        Computes the per-gene variance of the `axis_keys["x"]` and `axis_keys["y"]` groups and plots them
        against each other. A regression line is drawn and the R² of a linear fit is annotated. If
        `axis_keys` has a `"y1"` entry, that group's per-gene variance is also scattered against the x values.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.
            condition_key: Key of the condition.
            axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`, optionally with a `"y1"` entry.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: File path passed to `plt.savefig`; falsy values skip saving.
            gene_list: list of gene names to be plotted.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            legend: if `True`: plots a legend, defaults to `True`.
            title: Set if you want the plot to display a title.
            verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
            x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
            y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
            fontsize: Fontsize used for text in the plot, defaults to 14.
            **kwargs: Optional `range=(start, stop, step)` for the ticks of both axes, and `textsize` for the
                R² annotations (defaults to `fontsize`).

        Returns:
            R² over all genes. If `top_100_genes` is given, a tuple `(R² all genes, R² top 100 DEGs)`.
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        # NOTE(review): the ranking result is never read below, but the call writes
        # `adata.uns["rank_genes_groups"]`. It is kept to preserve that side effect.
        sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")
        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = np.asarray(np.var(ctrl_diff.X, axis=0)).ravel()
            y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
            r_value_diff = stats.linregress(x_diff, y_diff).rvalue
            if verbose:
                print("Top 100 DEGs var: ", r_value_diff**2)
        has_y1 = "y1" in axis_keys
        if has_y1:
            real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
        x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.var(stim.X, axis=0)).ravel()
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            print("All genes var: ", r_value**2)
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)
        if has_y1:
            y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
            _ = plt.scatter(
                x,
                y1,
                marker="*",
                c="grey",
                alpha=0.5,
                label=f"{axis_keys['x']}-{axis_keys['y1']}",
            )
        if gene_list is not None:
            # Build the name list once instead of once per labelled gene.
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                plt.text(x_bar, y_bar, gene, fontsize=11, color="black")
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
                if has_y1:
                    y1_bar = y1[j]
                    plt.text(x_bar, y1_bar, "*", color="black", alpha=0.5)
        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title("" if title is None else title, fontsize=12)
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()
        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        return r_value**2

    def plot_binary_classifier(
        self,
        scgen: SCGEN,
        adata: AnnData | None,
        delta: np.ndarray,
        ctrl_key: str,
        stim_key: str,
        show: bool = False,
        save: str | bool | None = None,
        fontsize: float = 14,
    ) -> plt.Axes | None:
        """Plots the dot product between delta and latent representation of a linear classifier.

        Builds a linear classifier from the dot product between the difference vector and each
        cell's latent representation. Plots histograms of these dot products for control and
        stimulated cells.

        Args:
            scgen: ScGen object that was trained.
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.
            delta: Difference between stimulated and control cells in latent space
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            show: If `True`, call `plt.show()`.
            save: File path passed to `plt.savefig`; falsy values skip saving.
            fontsize: Set the font size of the plot.

        Returns:
            The current axes if neither `show` nor `save` is set, otherwise `None`.
        """
        plt.close("all")
        adata = scgen._validate_anndata(adata)
        condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
        cd = adata[adata.obs[condition_key] == ctrl_key, :]
        stim = adata[adata.obs[condition_key] == stim_key, :]
        # get_latent_representation validates and loads AnnData objects, so pass the subsets, not raw `.X`.
        all_latent_cd = np.asarray(scgen.get_latent_representation(cd))
        all_latent_stim = np.asarray(scgen.get_latent_representation(stim))
        delta_vec = np.ravel(np.asarray(delta))
        dot_cd = all_latent_cd @ delta_vec
        dot_sal = all_latent_stim @ delta_vec
        plt.hist(dot_cd, label=ctrl_key, bins=50)
        plt.hist(dot_sal, label=stim_key, bins=50)
        plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
        plt.title(" ", fontsize=fontsize)
        plt.xlabel(" ", fontsize=fontsize)
        plt.ylabel(" ", fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        ax = plt.gca()
        ax.grid(False)

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if not (show or save):
            return ax
        return None