pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,706 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import jax.numpy as jnp
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import scanpy as sc
10
+ from adjustText import adjust_text
11
+ from anndata import AnnData
12
+ from jax import Array
13
+ from lamin_utils import logger
14
+ from scipy import stats
15
+ from scvi import REGISTRY_KEYS
16
+ from scvi.data import AnnDataManager
17
+ from scvi.data.fields import CategoricalObsField, LayerField
18
+ from scvi.model.base import BaseModelClass, JaxTrainingMixin
19
+ from scvi.utils import setup_anndata_dsp
20
+
21
+ from ._scgenvae import JaxSCGENVAE
22
+ from ._utils import balancer, extractor
23
+
24
+ if TYPE_CHECKING:
25
+ from collections.abc import Sequence
26
+
27
# Matplotlib-style font specification (family + size) defined at module level.
font = dict(family="Arial", size=14)
28
+
29
+
30
class Scgen(JaxTrainingMixin, BaseModelClass):
    """Jax implementation of the scGen model for batch removal and perturbation prediction.

    The AnnData passed to the constructor is expected to be registered with :meth:`Scgen.setup_anndata`
    (see Examples), which records the batch/condition key and the labels/cell type key that the
    methods of this class read back from the AnnData manager.

    Examples:
        >>> import pertpy as pt
        >>> data = pt.dt.kang_2018()
        >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        >>> model = pt.tl.Scgen(data)
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 800,
        n_latent: int = 100,
        n_layers: int = 2,
        dropout_rate: float = 0.2,
        **model_kwargs,
    ):
        """Build the scGen model for a registered AnnData object.

        Args:
            adata: AnnData object registered via :meth:`Scgen.setup_anndata`.
            n_hidden: Forwarded to :class:`JaxSCGENVAE` as ``n_hidden``.
            n_latent: Forwarded to :class:`JaxSCGENVAE` as ``n_latent``.
            n_layers: Forwarded to :class:`JaxSCGENVAE` as ``n_layers``.
            dropout_rate: Forwarded to :class:`JaxSCGENVAE` as ``dropout_rate``.
            **model_kwargs: Additional keyword arguments forwarded to :class:`JaxSCGENVAE`.
        """
        super().__init__(adata)

        self.module = JaxSCGENVAE(
            n_input=self.summary_stats.n_vars,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"Scgen Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}"
        )
        self.init_params_ = self._get_init_params(locals())

    def predict(
        self,
        ctrl_key=None,
        stim_key=None,
        adata_to_predict=None,
        celltype_to_predict=None,
        restrict_arithmetic_to="all",
    ) -> tuple[AnnData, Any]:
        """Predicts the cell type provided by the user in stimulated condition.

        The perturbation vector ``delta`` is the mean latent representation of stimulated cells minus the
        mean latent representation of control cells. It is added to the latent representation of the cells
        to predict, and the result is decoded back to gene-expression space.

        Args:
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            adata_to_predict: Adata for unperturbed cells you want to be predicted.
            celltype_to_predict: The cell type you want to be predicted.
            restrict_arithmetic_to: ``"all"`` to use all cells for computing ``delta``, or a dictionary
                ``{obs_key: values}`` whose first key restricts the computation to cells with
                ``adata.obs[obs_key]`` in ``values``.

        Returns:
            A tuple ``(predicted_adata, delta)``. ``predicted_adata`` holds the decoded predictions, with
            ``obs``, ``var`` and ``obsm`` copied from the cells being predicted. ``delta`` is the latent-space
            difference vector between stimulated and control cells.

        Raises:
            ValueError: If both or neither of ``celltype_to_predict`` and ``adata_to_predict`` are given, or if
                no control or no stimulated cells are available for computing ``delta``.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        """
        # Validate the mutually exclusive inputs before doing any work.
        if celltype_to_predict is not None and adata_to_predict is not None:
            raise ValueError("Please provide either a cell type or adata not both!")
        if celltype_to_predict is None and adata_to_predict is None:
            raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

        # use keys registered from `setup_anndata()`
        cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        if restrict_arithmetic_to == "all":
            ctrl_x = self.adata[self.adata.obs[condition_key] == ctrl_key, :]
            stim_x = self.adata[self.adata.obs[condition_key] == stim_key, :]
            ctrl_x = balancer(ctrl_x, cell_type_key)
            stim_x = balancer(stim_x, cell_type_key)
        else:
            key = next(iter(restrict_arithmetic_to))
            values = restrict_arithmetic_to[key]
            subset = self.adata[self.adata.obs[key].isin(values)]
            ctrl_x = subset[subset.obs[condition_key] == ctrl_key, :]
            stim_x = subset[subset.obs[condition_key] == stim_key, :]
            if len(values) > 1:
                ctrl_x = balancer(ctrl_x, cell_type_key)
                stim_x = balancer(stim_x, cell_type_key)

        if celltype_to_predict is not None:
            # NOTE(review): element 1 of `extractor` is used as the cells to predict
            # (presumably the control cells of that cell type) - confirm in `_utils.extractor`.
            ctrl_pred = extractor(
                self.adata,
                celltype_to_predict,
                condition_key,
                cell_type_key,
                ctrl_key,
                stim_key,
            )[1]
        else:
            ctrl_pred = adata_to_predict

        n_equal = min(ctrl_x.n_obs, stim_x.n_obs)
        if n_equal == 0:
            raise ValueError(
                f"No cells found for ctrl_key={ctrl_key!r} and/or stim_key={stim_key!r} in adata.obs['{condition_key}']."
            )

        # Draw equally many control and stimulated cells (without replacement) before averaging.
        rng = np.random.default_rng()
        ctrl_idx = rng.choice(ctrl_x.n_obs, size=n_equal, replace=False)
        stim_idx = rng.choice(stim_x.n_obs, size=n_equal, replace=False)

        latent_ctrl = self._avg_vector(ctrl_x[ctrl_idx, :])
        latent_stim = self._avg_vector(stim_x[stim_idx, :])
        delta = latent_stim - latent_ctrl

        latent_cd = self.get_latent_representation(ctrl_pred)
        stim_pred = delta + latent_cd
        predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

        predicted_adata = AnnData(
            X=np.array(predicted_cells),
            obs=ctrl_pred.obs.copy(),
            var=ctrl_pred.var.copy(),
            obsm=ctrl_pred.obsm.copy(),
        )
        return predicted_adata, delta

    def _avg_vector(self, adata):
        """Return the mean latent representation (averaged over cells) of ``adata``."""
        return np.mean(self.get_latent_representation(adata), axis=0)

    def get_decoded_expression(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> Array:
        """Get decoded expression.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Decoded expression for each cell

        Raises:
            RuntimeError: If the model has not been trained yet.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
        if not self.is_trained_:
            raise RuntimeError("Please train the model first.")

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        decoded = []
        for tensors in scdl:
            _, generative_outputs = self.module.as_bound()(tensors, compute_loss=False)
            decoded.append(generative_outputs["px"])

        return jnp.concatenate(decoded)

    @staticmethod
    def _shift_latent_to_largest_batch(
        latent: np.ndarray,
        cell_labels: np.ndarray,
        batch_labels: np.ndarray,
    ) -> np.ndarray:
        """Return a copy of ``latent`` with each cell type's batches shifted onto that cell type's largest batch.

        For every cell type present in at least two batches, each batch's cells are moved by
        ``mean(latent of largest batch) - mean(latent of that batch)``. Ties for the largest batch are broken by
        the first batch in sorted order. Cell types seen in a single batch are left unchanged. Rows keep their
        original order.
        """
        shifted = np.array(latent, copy=True)
        for cell_type in np.unique(cell_labels):
            in_type = cell_labels == cell_type
            batches, counts = np.unique(batch_labels[in_type], return_counts=True)
            if len(batches) < 2:
                continue
            reference = batches[np.argmax(counts)]
            reference_mean = latent[in_type & (batch_labels == reference)].mean(axis=0)
            for batch in batches:
                rows = in_type & (batch_labels == batch)
                shifted[rows] += reference_mean - latent[rows].mean(axis=0)
        return shifted

    def batch_removal(self, adata: AnnData | None = None) -> AnnData:
        """Removes batch effects.

        Within every cell type that appears in more than one batch, the latent representation of each batch is
        shifted so that its mean matches the mean of that cell type's largest batch. The shifted latent space is
        then decoded back to gene-expression space.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.

        Returns:
            corrected: `~anndata.AnnData` in the same cell order as ``adata``. It contains:

            - ``X``: the decoded, batch-corrected gene expression.
            - ``obsm["latent"]``: the shifted latent representation.
            - ``obsm["corrected_latent"]``: the latent representation of the corrected expression.
            - ``raw``: built from ``adata.raw`` when the input has a ``raw`` attribute.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
        adata = self._validate_anndata(adata)
        latent_all = np.asarray(self.get_latent_representation(adata))
        # use keys registered from `setup_anndata()`
        cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        # Shift in place (original row order), so obs, obsm["latent"] and X stay aligned without re-indexing.
        shifted_latent = self._shift_latent_to_largest_batch(
            latent_all,
            adata.obs[cell_label_key].to_numpy(),
            adata.obs[batch_key].to_numpy(),
        )

        corrected = AnnData(
            np.array(self.module.as_bound().generative(shifted_latent)["px"]),
            obs=adata.obs.copy(),
        )
        corrected.var_names = adata.var_names.tolist()
        if adata.raw is not None:
            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        corrected.obsm["latent"] = shifted_latent
        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)
        return corrected

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        labels_key: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        scGen expects the expression data to come from `adata.X`

        %(param_batch_key)s
        %(param_labels_key)s

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def to_device(self, device):
        """Do nothing; this model does not move itself between devices here."""
        pass

    @property
    def device(self):
        """Device reported by the underlying module."""
        return self.module.device

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        n_samples: int = 1,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Return the latent representation for each cell.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            give_mean: If `True`, return the mean of the latent distribution (``qz.mean``); otherwise return
                the sampled latent ``z``.
            n_samples: Number of samples passed to the inference function. When ``give_mean`` is `False` and
                ``n_samples > 1``, minibatches are concatenated along axis 1 instead of axis 0.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Low-dimensional representation for each cell

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
        self._check_if_trained(warn=False)

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)

        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

        latent = []
        for array_dict in scdl:
            out = jit_inference_fn(self.module.rngs, array_dict)
            latent.append(out["qz"].mean if give_mean else out["z"])
        concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
        latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore

        return self.module.as_numpy_array(latent)

    @staticmethod
    def _column_variance(X) -> np.ndarray:
        """Per-column population variance (ddof=0) of a dense or scipy-sparse matrix, as a 1-D array."""
        from scipy import sparse

        if sparse.issparse(X):
            # np.var does not support scipy sparse matrices: use E[X^2] - E[X]^2 instead.
            mean = np.asarray(X.mean(axis=0)).ravel()
            mean_sq = np.asarray(X.multiply(X).mean(axis=0)).ravel()
            return mean_sq - mean**2
        return np.asarray(np.var(X, axis=0)).ravel()

    def plot_reg_mean_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        show: bool = False,
        top_100_genes: list[str] | None = None,
        verbose: bool = False,
        legend: bool = True,
        title: str | None = None,
        x_coeff: float = 0.30,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots mean matching for a set of specified genes.

        For every gene, the mean expression over cells with ``adata.obs[condition_key] == axis_keys["x"]`` is
        plotted against the mean over cells with ``adata.obs[condition_key] == axis_keys["y"]``. A linear
        regression fit and its R² value are shown as well.

        Args:
            adata: AnnData object containing the cells of both compared conditions.
            condition_key: ``adata.obs`` column holding the condition of each cell.
            axis_keys: Values of ``condition_key`` used for the axes. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: Path to save the plot to. Falsy values disable saving.
            gene_list: list of gene names to be highlighted and annotated.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            verbose: Specify if you want the R² values to be logged while creating the plot.
            legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
            x_coeff: Offset to print the R^2 value in x-direction.
            y_coeff: Offset to print the R^2 value in y-direction.
            fontsize: Fontsize used for text in the plot.
            **kwargs: ``range=(start, stop, step)`` sets identical x/y ticks. ``textsize`` sets the font size of the
                R² annotations (defaults to ``fontsize``).

        Returns:
            R² over all genes, or ``(R² all genes, R² top 100 DEGs)`` if ``top_100_genes`` is given.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> scg = pt.tl.Scgen(data)
            >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
            >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
            >>> r2_value = scg.plot_reg_mean_plot(
            ...     eval_adata,
            ...     condition_key="label",
            ...     axis_keys={"x": "pred", "y": "stim"},
            ...     labels={"x": "predicted", "y": "ground truth"},
            ...     save=False,
            ...     show=True,
            ... )

        Preview:
            .. image:: /_static/docstring_previews/scgen_reg_mean.png
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
            y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
            r_value_diff = stats.linregress(x_diff, y_diff).rvalue
            if verbose:
                logger.info(f"top_100 DEGs mean: {r_value_diff**2}")
        x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.mean(stim.X, axis=0)).ravel()
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            logger.info(f"All genes mean: {r_value**2}")
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)
        if gene_list is not None:
            var_names = adata.var_names.tolist()
            texts = []
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                texts.append(plt.text(x_bar, y_bar, gene, fontsize=11, color="black"))
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
            # Move the gene labels so they do not overlap each other or the points.
            adjust_text(
                texts,
                x=x,
                y=y,
                arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
                force_static=(0.0, 0.0),
            )
        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title("" if title is None else title, fontsize=fontsize)
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )
        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()
        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        return r_value**2

    def plot_reg_var_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        top_100_genes: list[str] | None = None,
        show: bool = False,
        legend: bool = True,
        title: str | None = None,
        verbose: bool = False,
        x_coeff: float = 0.3,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots variance matching for a set of specified genes.

        For every gene, the variance over cells with ``adata.obs[condition_key] == axis_keys["x"]`` is plotted
        against the variance over cells with ``adata.obs[condition_key] == axis_keys["y"]``. A linear regression
        fit and its R² value are shown as well. Dense and scipy-sparse ``adata.X`` are both supported.

        Args:
            adata: AnnData object containing the cells of the compared conditions.
            condition_key: Key of the condition (``adata.obs`` column).
            axis_keys: Values of ``condition_key`` used for the axes. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`. An optional ``"y1"`` entry adds a grey scatter
                of the variances of that condition against the x-axis condition.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: Path to save the plot to. Falsy values disable saving.
            gene_list: list of gene names to be highlighted and annotated.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            show: if `True`: will show to the plot after saving it.
            legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
            verbose: Specify if you want the R² values to be logged while creating the plot.
            x_coeff: Offset to print the R^2 value in x-direction.
            y_coeff: Offset to print the R^2 value in y-direction.
            fontsize: Fontsize used for text in the plot.
            **kwargs: ``range=(start, stop, step)`` sets identical x/y ticks. ``textsize`` sets the font size of the
                R² annotations (defaults to ``fontsize``).

        Returns:
            R² over all genes, or ``(R² all genes, R² top 100 DEGs)`` if ``top_100_genes`` is given.

        Examples:
            >>> r2_value = scg.plot_reg_var_plot(
            ...     eval_adata,
            ...     condition_key="label",
            ...     axis_keys={"x": "pred", "y": "stim"},
            ...     labels={"x": "predicted", "y": "ground truth"},
            ...     save=False,
            ...     show=True,
            ... )
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = self._column_variance(ctrl_diff.X)
            y_diff = self._column_variance(stim_diff.X)
            r_value_diff = stats.linregress(x_diff, y_diff).rvalue
            if verbose:
                logger.info(f"Top 100 DEGs var: {r_value_diff**2}")
        x = self._column_variance(ctrl.X)
        y = self._column_variance(stim.X)
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            logger.info(f"All genes var: {r_value**2}")
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)
        has_y1 = "y1" in axis_keys
        if has_y1:
            real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
            y1 = self._column_variance(real_stim.X)
            plt.scatter(
                x,
                y1,
                marker="*",
                c="grey",
                alpha=0.5,
                label=f"{axis_keys['x']}-{axis_keys['y1']}",
            )
        if gene_list is not None:
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                plt.text(x_bar, y_bar, gene, fontsize=11, color="black")
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
                if has_y1:
                    plt.text(x_bar, y1[j], "*", color="black", alpha=0.5)
        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title("" if title is None else title, fontsize=12)
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()
        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        return r_value**2

    def plot_binary_classifier(
        self,
        scgen: Scgen,
        adata: AnnData | None,
        delta: np.ndarray,
        ctrl_key: str,
        stim_key: str,
        show: bool = False,
        save: str | bool | None = None,
        fontsize: float = 14,
    ) -> plt.Axes | None:
        """Plots the dot product between delta and latent representation of a linear classifier.

        Builds a linear classifier based on the dot product between
        the difference vector and the latent representation of each
        cell and plots the dot product results between delta and latent representation.

        Args:
            scgen: ScGen object that was trained.
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.
            delta: Difference between stimulated and control cells in latent space
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            show: If `True`, show the plot.
            save: Path to save the plot to. Falsy values disable saving.
            fontsize: Set the font size of the plot.

        Returns:
            The matplotlib Axes if neither ``show`` nor ``save`` is set, otherwise `None`.
        """
        plt.close("all")
        adata = scgen._validate_anndata(adata)
        condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
        cd = adata[adata.obs[condition_key] == ctrl_key, :]
        stim = adata[adata.obs[condition_key] == stim_key, :]
        # get_latent_representation expects AnnData objects, not their raw `.X` matrices.
        all_latent_cd = scgen.get_latent_representation(cd)
        all_latent_stim = scgen.get_latent_representation(stim)
        delta_vec = np.asarray(delta)
        dot_cd = np.asarray(all_latent_cd) @ delta_vec
        dot_sal = np.asarray(all_latent_stim) @ delta_vec
        plt.hist(dot_cd, label=ctrl_key, bins=50)
        plt.hist(dot_sal, label=stim_key, bins=50)
        plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
        plt.title(" ", fontsize=fontsize)
        plt.xlabel(" ", fontsize=fontsize)
        plt.ylabel(" ", fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        ax = plt.gca()
        ax.grid(False)

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if not (show or save):
            return ax
        return None
703
+
704
+
705
# Backward-compatibility alias: `SCGEN` is the very same class object as `Scgen`.
SCGEN: type[Scgen] = Scgen