pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,701 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import jax.numpy as jnp
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import scanpy as sc
10
+ from adjustText import adjust_text
11
+ from anndata import AnnData
12
+ from jax import Array
13
+ from scipy import stats
14
+ from scvi import REGISTRY_KEYS
15
+ from scvi.data import AnnDataManager
16
+ from scvi.data.fields import CategoricalObsField, LayerField
17
+ from scvi.model.base import BaseModelClass, JaxTrainingMixin
18
+ from scvi.utils import setup_anndata_dsp
19
+
20
+ from ._scgenvae import JaxSCGENVAE
21
+ from ._utils import balancer, extractor
22
+
23
+ if TYPE_CHECKING:
24
+ from collections.abc import Sequence
25
+
26
+ font = {"family": "Arial", "size": 14}
27
+
28
+
29
+ class SCGEN(JaxTrainingMixin, BaseModelClass):
30
+ """Jax Implementation of scGen model for batch removal and perturbation prediction."""
31
+
32
def __init__(
    self,
    adata: AnnData,
    n_hidden: int = 800,
    n_latent: int = 100,
    n_layers: int = 2,
    dropout_rate: float = 0.2,
    **model_kwargs,
):
    """Set up the model for ``adata`` and build the underlying ``JaxSCGENVAE`` module.

    Args:
        adata: AnnData object handed to the base class initializer.
        n_hidden: Forwarded to ``JaxSCGENVAE`` as ``n_hidden``.
        n_latent: Forwarded to ``JaxSCGENVAE`` as ``n_latent``.
        n_layers: Forwarded to ``JaxSCGENVAE`` as ``n_layers``.
        dropout_rate: Forwarded to ``JaxSCGENVAE`` as ``dropout_rate``.
        **model_kwargs: Additional keyword arguments forwarded to ``JaxSCGENVAE``.
    """
    super().__init__(adata)

    # The input dimension is taken from the summary statistics of the registered data.
    self.module = JaxSCGENVAE(
        **model_kwargs,
        dropout_rate=dropout_rate,
        n_layers=n_layers,
        n_latent=n_latent,
        n_hidden=n_hidden,
        n_input=self.summary_stats.n_vars,
    )
    self._model_summary_string = f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: {dropout_rate}"
    # Keep this as the last statement: it snapshots the local namespace of the initializer.
    self.init_params_ = self._get_init_params(locals())
56
+
57
def predict(
    self,
    ctrl_key=None,
    stim_key=None,
    adata_to_predict=None,
    celltype_to_predict=None,
    restrict_arithmetic_to="all",
) -> tuple[AnnData, Any]:
    """Predicts the stimulated state of the requested cells via latent space arithmetic.

    The perturbation vector ``delta`` is the difference between the mean latent representation of
    stimulated and control cells. It is added to the latent representation of the cells to predict
    and the shifted latent vectors are decoded with the generative part of the module.

    Args:
        ctrl_key: Value marking control cells in the condition column, i.e. the column that was
            registered as `batch_key` in `setup_anndata`.
        stim_key: Value marking stimulated cells in the condition column.
        adata_to_predict: AnnData of unperturbed cells to predict.
            Mutually exclusive with `celltype_to_predict`.
        celltype_to_predict: Cell type whose cells should be predicted. The cells are selected from
            the model's AnnData with the `extractor` helper (its second return value is used).
            Mutually exclusive with `adata_to_predict`.
        restrict_arithmetic_to: Either `"all"` to compute `delta` from all cells, or a dictionary
            mapping an `adata.obs` column to the values to keep, e.g. `{"cell_type": ["CD4 T cells"]}`.
            Only the first key of the dictionary is used.

    Returns:
        A tuple `(predicted_adata, delta)`. `predicted_adata` is an AnnData holding the decoded
        prediction in `.X` with `obs`, `var` and `obsm` copied from the cells that were predicted.
        `delta` is the difference between the mean stimulated and mean control latent vectors.

    Raises:
        ValueError: If both or neither of `celltype_to_predict` and `adata_to_predict` are given,
            or if no control or no stimulated cells are found.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
    """
    # Reject inconsistent arguments up front, before any subsetting or balancing work is done.
    if celltype_to_predict is not None and adata_to_predict is not None:
        raise ValueError("Please provide either a cell type or adata not both!")
    if celltype_to_predict is None and adata_to_predict is None:
        raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

    # use keys registered from `setup_anndata()`
    cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
    condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

    if restrict_arithmetic_to == "all":
        source = self.adata
        needs_balancing = True
    else:
        key = next(iter(restrict_arithmetic_to))
        values = restrict_arithmetic_to[key]
        source = self.adata[self.adata.obs[key].isin(values)]
        # With a single selected value the subset is passed on without balancing.
        needs_balancing = len(values) > 1

    ctrl_x = source[source.obs[condition_key] == ctrl_key, :]
    stim_x = source[source.obs[condition_key] == stim_key, :]
    if ctrl_x.shape[0] == 0 or stim_x.shape[0] == 0:
        raise ValueError(
            f"No cells found for ctrl_key={ctrl_key!r} and/or stim_key={stim_key!r} in `adata.obs[{condition_key!r}]`."
        )
    if needs_balancing:
        ctrl_x = balancer(ctrl_x, cell_type_key)
        stim_x = balancer(stim_x, cell_type_key)

    if celltype_to_predict is not None:
        ctrl_pred = extractor(
            self.adata,
            celltype_to_predict,
            condition_key,
            cell_type_key,
            ctrl_key,
            stim_key,
        )[1]
    else:
        ctrl_pred = adata_to_predict

    # Draw the same number of control and stimulated cells (without replacement) for the mean vectors.
    n_cells = min(ctrl_x.shape[0], stim_x.shape[0])
    rng = np.random.default_rng()
    ctrl_indices = rng.choice(ctrl_x.shape[0], size=n_cells, replace=False)
    stim_indices = rng.choice(stim_x.shape[0], size=n_cells, replace=False)

    latent_ctrl = self._avg_vector(ctrl_x[ctrl_indices, :])
    latent_stim = self._avg_vector(stim_x[stim_indices, :])
    delta = latent_stim - latent_ctrl

    # Shift the latent representation of the cells to predict and decode it.
    latent_cd = self.get_latent_representation(ctrl_pred)
    stim_pred = delta + latent_cd
    predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

    predicted_adata = AnnData(
        X=np.array(predicted_cells),
        obs=ctrl_pred.obs.copy(),
        var=ctrl_pred.var.copy(),
        obsm=ctrl_pred.obsm.copy(),
    )
    return predicted_adata, delta
145
+
146
def _avg_vector(self, adata):
    """Return the mean over cells (axis 0) of the latent representation of ``adata``."""
    latent = self.get_latent_representation(adata)
    return np.mean(latent, axis=0)
148
+
149
def get_decoded_expression(
    self,
    adata: AnnData | None = None,
    indices: Sequence[int] | None = None,
    batch_size: int | None = None,
) -> Array:
    """Run the module on the data and collect the decoded expression (`px`).

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices: Indices of cells in adata to use. If `None`, all cells are used.
        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns:
        The `px` outputs of all minibatches, concatenated in data loader order.

    Raises:
        RuntimeError: If the model has not been trained.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> decoded_X = model.get_decoded_expression()
    """
    if self.is_trained_ is False:
        raise RuntimeError("Please train the model first.")

    adata = self._validate_anndata(adata)
    loader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)

    pieces = []
    for minibatch in loader:
        # The bound module is requested anew for every minibatch; only the generative outputs are needed.
        outputs = self.module.as_bound()(minibatch, compute_loss=False)
        _inference_outputs, generative_outputs = outputs
        pieces.append(generative_outputs["px"])

    return jnp.concatenate(pieces)
186
+
187
def batch_removal(self, adata: AnnData | None = None) -> AnnData:
    """Removes batch effects.

    For every cell type that occurs in at least two batches, the latent vectors of each batch are
    shifted so that the batch mean matches the mean of the largest batch of that cell type.
    Cell types occurring in a single batch are left unshifted. All latent vectors are then decoded.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
            corresponding to batch and cell type metadata, respectively.

    Returns:
        corrected: `~anndata.AnnData`
        AnnData with the rows in the order of `adata.obs_names`. `.X` holds the decoded (corrected)
        expression, `.obsm["latent"]` the shifted latent vectors that were decoded, and
        `.obsm["corrected_latent"]` the latent representation of the corrected expression.
        If the input adata has a `raw` attribute, its `X` and `var` are stored in `corrected.raw`.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> corrected_adata = model.batch_removal()
    """
    adata = self._validate_anndata(adata)
    latent_all = self.get_latent_representation(adata)
    # use keys registered from `setup_anndata()`
    cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
    batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

    adata_latent = AnnData(latent_all)
    adata_latent.obs = adata.obs.copy(deep=True)

    shared_ct = []
    not_shared_ct = []
    for cell_type in np.unique(adata_latent.obs[cell_label_key]):
        cells_of_type = adata_latent[adata_latent.obs[cell_label_key] == cell_type]
        batches = np.unique(cells_of_type.obs[batch_key])
        if len(batches) < 2:
            # Present in one batch only: nothing to align.
            not_shared_ct.append(cells_of_type)
            continue

        temp_cell = cells_of_type.copy()
        batch_masks = {batch: temp_cell.obs[batch_key] == batch for batch in batches}
        batch_sizes = {batch: int(mask.sum()) for batch, mask in batch_masks.items()}
        # The largest batch is the reference; `max` keeps the first batch on ties.
        reference_batch = max(batch_sizes, key=batch_sizes.get)
        # The reference batch itself receives a zero shift, so its mean can be computed once.
        reference_mean = np.average(temp_cell[batch_masks[reference_batch]].X, axis=0)
        for mask in batch_masks.values():
            batch_latent = temp_cell[mask].X
            delta = reference_mean - np.average(batch_latent, axis=0)
            temp_cell[mask].X = delta + batch_latent
        shared_ct.append(temp_cell)

    # Shared cell types first, then the untouched ones. A single concatenation also covers the
    # case in which no cell type is shared between batches.
    all_corrected_data = AnnData.concatenate(
        *shared_ct,
        *not_shared_ct,
        batch_key="concat_batch",
        index_unique=None,
    )
    # Drop the helper column of the concatenation so it does not end up in the result.
    if "concat_batch" in all_corrected_data.obs.columns:
        del all_corrected_data.obs["concat_batch"]

    corrected = AnnData(
        np.array(self.module.as_bound().generative(all_corrected_data.X)["px"]),
        obs=all_corrected_data.obs,
    )
    corrected.var_names = adata.var_names.tolist()
    # Attach the latent vectors before reordering so that they stay aligned with their cells.
    corrected.obsm["latent"] = all_corrected_data.X
    corrected = corrected[adata.obs_names].copy()
    if adata.raw is not None:
        adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
        adata_raw.obs_names = adata.obs_names
        corrected.raw = adata_raw
    corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)

    return corrected
287
+
288
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
    cls,
    adata: AnnData,
    batch_key: str | None = None,
    labels_key: str | None = None,
    **kwargs,
):
    """%(summary)s.

    scGen expects the expression data to come from `adata.X`

    %(param_batch_key)s
    %(param_labels_key)s

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
    """
    # Must stay the first statement: it captures exactly the call arguments via `locals()`.
    setup_method_args = cls._get_setup_method_args(**locals())
    manager = AnnDataManager(
        fields=[
            # Expression is read from `adata.X` (layer `None`) and not treated as count data.
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
318
+
319
def to_device(self, device):
    """Do nothing: the requested ``device`` is ignored and ``None`` is returned."""
    return None
321
+
322
@property
def device(self):
    """Device reported by the underlying module (``self.module.device``)."""
    module = self.module
    return module.device
325
+
326
def get_latent_representation(
    self,
    adata: AnnData | None = None,
    indices: Sequence[int] | None = None,
    give_mean: bool = True,
    n_samples: int = 1,
    batch_size: int | None = None,
) -> np.ndarray:
    """Return the latent representation for each cell.

    Args:
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices: Indices of cells in adata to use. If `None`, all cells are used.
        give_mean: If `True`, the mean of the `qz` inference output is returned for each cell,
            otherwise the `z` inference output.
        n_samples: Passed to the inference function as `n_samples`. If it differs from 1 and
            `give_mean` is `False`, minibatches are concatenated along axis 1 instead of axis 0.
        batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns:
        Low-dimensional representation for each cell

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.SCGEN(data)
        >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> latent_X = model.get_latent_representation()
    """
    self._check_if_trained(warn=False)

    adata = self._validate_anndata(adata)
    loader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)
    infer = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

    chunks = []
    for minibatch in loader:
        # `self.module.rngs` is read anew for every minibatch.
        outputs = infer(self.module.rngs, minibatch)
        chunks.append(outputs["qz"].mean if give_mean else outputs["z"])

    if give_mean or n_samples == 1:
        axis = 0
    else:
        axis = 1
    stacked = jnp.concatenate(chunks, axis=axis)  # type: ignore

    return self.module.as_numpy_array(stacked)
372
+
373
def plot_reg_mean_plot(
    self,
    adata,
    condition_key: str,
    axis_keys: dict[str, str],
    labels: dict[str, str],
    save: str | bool | None = None,
    gene_list: list[str] | None = None,
    show: bool = False,
    top_100_genes: list[str] | None = None,
    verbose: bool = False,
    legend: bool = True,
    title: str | None = None,
    x_coeff: float = 0.30,
    y_coeff: float = 0.8,
    fontsize: float = 14,
    **kwargs,
) -> tuple[float, float] | float:
    """Plots mean matching for a set of specified genes.

    The per-gene mean expression of the cells selected by `axis_keys["x"]` is plotted against the
    per-gene mean expression of the cells selected by `axis_keys["y"]` together with a regression fit.

    Args:
        adata: AnnData object containing the cells of both conditions to compare.
        condition_key: The `adata.obs` column that holds the condition of each cell.
        axis_keys: Dictionary of condition values that are used by the axes of the plot.
            Has to be in the following form: `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
        save: If truthy, the value is passed to `matplotlib.pyplot.savefig` as the target of the figure.
        gene_list: list of gene names to be highlighted and annotated in the plot.
        show: if `True`: will show to the plot after saving it.
        top_100_genes: List of the top 100 differentially expressed genes.
            Specify if you want the top 100 DEGs to be assessed extra.
        verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
        legend: if `True`: plots a legend, defaults to `True`.
        title: Set if you want the plot to display a title.
        x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
        y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
        fontsize: Fontsize used for text in the plot, defaults to 14.
        **kwargs: Two optional keys are read: `range`, a `(start, stop, step)` triple used for the
            ticks of both axes, and `textsize`, the font size of the R^2 annotation.

    Returns:
        The R^2 value over all genes, or the tuple `(R^2 all genes, R^2 top 100 DEGs)`
        if `top_100_genes` is given.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> scg = pt.tl.SCGEN(data)
        >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
        >>> pred, delta = scg.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        >>> pred.obs["label"] = "pred"
        >>> eval_adata = data[data.obs["cell_type"] == "CD4 T cells"].copy().concatenate(pred)
        >>> r2_value = scg.plot_reg_mean_plot(
        ...     eval_adata,
        ...     condition_key="label",
        ...     axis_keys={"x": "pred", "y": "stim"},
        ...     labels={"x": "predicted", "y": "ground truth"},
        ...     save=False,
        ...     show=True,
        ... )

    Preview:
        .. image:: /_static/docstring_previews/scgen_reg_mean.png
    """
    import seaborn as sns

    sns.set_theme(color_codes=True)

    diff_genes = top_100_genes
    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

    r_value_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        adata_diff = adata[:, diff_genes]
        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
        x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
        y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
        r_value_diff = stats.linregress(x_diff, y_diff).rvalue
        if verbose:
            print("top_100 DEGs mean: ", r_value_diff**2)

    x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
    y = np.asarray(np.mean(stim.X, axis=0)).ravel()
    r_value = stats.linregress(x, y).rvalue
    if verbose:
        print("All genes mean: ", r_value**2)

    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs.get("range")
        ax.set_xticks(np.arange(start, stop, step))
        ax.set_yticks(np.arange(start, stop, step))
    ax.set_xlabel(labels["x"], fontsize=fontsize)
    ax.set_ylabel(labels["y"], fontsize=fontsize)

    if gene_list is not None:
        # Build the name list once instead of once per gene.
        var_names = adata.var_names.tolist()
        texts = []
        for gene in gene_list:
            position = var_names.index(gene)
            x_bar = x[position]
            y_bar = y[position]
            texts.append(plt.text(x_bar, y_bar, gene, fontsize=11, color="black"))
            plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
        adjust_text(
            texts,
            x=x,
            y=y,
            arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
            force_static=(0.0, 0.0),
        )

    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.title("" if title is None else title, fontsize=fontsize)

    # The R^2 annotations are positioned relative to the largest x and y values.
    x_max = x.max()
    y_max = y.max()
    textsize = kwargs.get("textsize", fontsize)
    ax.text(
        x_max - x_max * x_coeff,
        y_max - y_coeff * y_max,
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
        fontsize=textsize,
    )
    if diff_genes is not None:
        ax.text(
            x_max - x_max * x_coeff,
            y_max - (y_coeff + 0.15) * y_max,
            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
            fontsize=textsize,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

    if diff_genes is not None:
        return r_value**2, r_value_diff**2
    return r_value**2
509
+
510
def plot_reg_var_plot(
    self,
    adata,
    condition_key: str,
    axis_keys: dict[str, str],
    labels: dict[str, str],
    save: str | bool | None = None,
    gene_list: list[str] | None = None,
    top_100_genes: list[str] | None = None,
    show: bool = False,
    legend: bool = True,
    title: str | None = None,
    verbose: bool = False,
    x_coeff: float = 0.3,
    y_coeff: float = 0.8,
    fontsize: float = 14,
    **kwargs,
) -> tuple[float, float] | float:
    """Plots variance matching for a set of specified genes.

    The per-gene variance of the cells selected by `axis_keys["x"]` is plotted against the
    per-gene variance of the cells selected by `axis_keys["y"]` together with a regression fit.

    Args:
        adata: AnnData object containing the cells of the conditions to compare.
            `scanpy.tl.rank_genes_groups` is run on it, which stores its result in the object.
        condition_key: The `adata.obs` column that holds the condition of each cell.
        axis_keys: Dictionary of condition values that are used by the axes of the plot.
            Has to be in the following form: `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            An optional `"y1"` entry adds a grey scatter of the variances of that condition.
        labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
        save: If truthy, the value is passed to `matplotlib.pyplot.savefig` as the target of the figure.
        gene_list: list of gene names to be highlighted and annotated in the plot.
        top_100_genes: List of the top 100 differentially expressed genes.
            Specify if you want the top 100 DEGs to be assessed extra.
        show: if `True`: will show to the plot after saving it.
        legend: if `True`: plots a legend, defaults to `True`.
        title: Set if you want the plot to display a title.
        verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
        x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
        y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
        fontsize: Fontsize used for the ticks, axis labels and R^2 text; the title uses a fixed size of 12.
        **kwargs: Two optional keys are read: `range`, a `(start, stop, step)` triple used for the
            ticks of both axes, and `textsize`, the font size of the R^2 annotation.

    Returns:
        The R^2 value over all genes, or the tuple `(R^2 all genes, R^2 top 100 DEGs)`
        if `top_100_genes` is given.
    """
    import seaborn as sns

    sns.set_theme(color_codes=True)

    # NOTE(review): the ranking result is not read in this function; the call is kept because it
    # writes to `adata` and callers may rely on that - confirm whether it can be dropped.
    sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")

    diff_genes = top_100_genes
    stim = adata[adata.obs[condition_key] == axis_keys["y"]]
    ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]

    r_value_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        adata_diff = adata[:, diff_genes]
        stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
        ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
        x_diff = np.asarray(np.var(ctrl_diff.X, axis=0)).ravel()
        y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
        r_value_diff = stats.linregress(x_diff, y_diff).rvalue
        if verbose:
            print("Top 100 DEGs var: ", r_value_diff**2)

    has_y1 = "y1" in axis_keys
    if has_y1:
        real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]

    x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
    y = np.asarray(np.var(stim.X, axis=0)).ravel()
    r_value = stats.linregress(x, y).rvalue
    if verbose:
        print("All genes var: ", r_value**2)

    df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
    ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs.get("range")
        ax.set_xticks(np.arange(start, stop, step))
        ax.set_yticks(np.arange(start, stop, step))
    ax.set_xlabel(labels["x"], fontsize=fontsize)
    ax.set_ylabel(labels["y"], fontsize=fontsize)

    if has_y1:
        y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
        plt.scatter(
            x,
            y1,
            marker="*",
            c="grey",
            alpha=0.5,
            label=f"{axis_keys['x']}-{axis_keys['y1']}",
        )

    if gene_list is not None:
        # Build the name list once instead of once per gene.
        var_names = adata.var_names.tolist()
        for gene in gene_list:
            position = var_names.index(gene)
            x_bar = x[position]
            y_bar = y[position]
            plt.text(x_bar, y_bar, gene, fontsize=11, color="black")
            plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
            if has_y1:
                plt.text(x_bar, y1[position], "*", color="black", alpha=0.5)

    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.title("" if title is None else title, fontsize=12)

    # The R^2 annotations are positioned relative to the largest x and y values.
    x_max = x.max()
    y_max = y.max()
    textsize = kwargs.get("textsize", fontsize)
    ax.text(
        x_max - x_max * x_coeff,
        y_max - y_coeff * y_max,
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
        fontsize=textsize,
    )
    if diff_genes is not None:
        ax.text(
            x_max - x_max * x_coeff,
            y_max - (y_coeff + 0.15) * y_max,
            r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
            fontsize=textsize,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

    if diff_genes is not None:
        return r_value**2, r_value_diff**2
    return r_value**2
637
+
638
def plot_binary_classifier(
    self,
    scgen: SCGEN,
    adata: AnnData | None,
    delta: np.ndarray,
    ctrl_key: str,
    stim_key: str,
    show: bool = False,
    save: str | bool | None = None,
    fontsize: float = 14,
) -> plt.Axes | None:
    """Plots the dot product between delta and latent representation of a linear classifier.

    Builds a linear classifier based on the dot product between
    the difference vector and the latent representation of each
    cell and plots the dot product results between delta and latent representation.

    Args:
        scgen: ScGen object that was trained.
        adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
            corresponding to batch and cell type metadata, respectively.
        delta: Difference between stimulated and control cells in latent space
        ctrl_key: Key for `control` part of the `data` found in `condition_key`.
        stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
        show: If `True`, the figure is shown.
        save: If truthy, the value is passed to `matplotlib.pyplot.savefig` as the target of the figure.
        fontsize: Set the font size of the plot.

    Returns:
        The axes of the plot if the figure is neither shown nor saved, `None` otherwise.
    """
    plt.close("all")
    adata = scgen._validate_anndata(adata)
    condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
    cd = adata[adata.obs[condition_key] == ctrl_key, :]
    stim = adata[adata.obs[condition_key] == stim_key, :]

    # `get_latent_representation` validates its argument as an AnnData object,
    # so the subsets themselves are passed rather than their `.X` matrices.
    all_latent_cd = scgen.get_latent_representation(cd)
    all_latent_stim = scgen.get_latent_representation(stim)

    # One dot product with `delta` per cell.
    dot_cd = np.asarray(np.dot(all_latent_cd, delta), dtype=float)
    dot_sal = np.asarray(np.dot(all_latent_stim, delta), dtype=float)

    plt.hist(
        dot_cd,
        label=ctrl_key,
        bins=50,
    )
    plt.hist(dot_sal, label=stim_key, bins=50)
    plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
    plt.title("  ", fontsize=fontsize)
    plt.xlabel("  ", fontsize=fontsize)
    plt.ylabel("  ", fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    ax = plt.gca()
    ax.grid(False)

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if not (show or save):
        return ax
    return None