pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,706 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import jax.numpy as jnp
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import scanpy as sc
10
+ from adjustText import adjust_text
11
+ from anndata import AnnData
12
+ from jax import Array
13
+ from lamin_utils import logger
14
+ from scipy import stats
15
+ from scvi import REGISTRY_KEYS
16
+ from scvi.data import AnnDataManager
17
+ from scvi.data.fields import CategoricalObsField, LayerField
18
+ from scvi.model.base import BaseModelClass, JaxTrainingMixin
19
+ from scvi.utils import setup_anndata_dsp
20
+
21
+ from ._scgenvae import JaxSCGENVAE
22
+ from ._utils import balancer, extractor
23
+
24
+ if TYPE_CHECKING:
25
+ from collections.abc import Sequence
26
+
27
# Default font settings (family and size); not referenced elsewhere in this module.
font = dict(family="Arial", size=14)
28
+
29
+
30
class Scgen(JaxTrainingMixin, BaseModelClass):
    """Jax Implementation of scGen model for batch removal and perturbation prediction."""

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 800,
        n_latent: int = 100,
        n_layers: int = 2,
        dropout_rate: float = 0.2,
        **model_kwargs,
    ):
        """Create the scGen model for an AnnData object registered with :meth:`setup_anndata`.

        Args:
            adata: AnnData object registered via :meth:`Scgen.setup_anndata`.
            n_hidden: Passed to :class:`JaxSCGENVAE` as ``n_hidden``.
            n_latent: Passed to :class:`JaxSCGENVAE` as ``n_latent``.
            n_layers: Passed to :class:`JaxSCGENVAE` as ``n_layers``.
            dropout_rate: Passed to :class:`JaxSCGENVAE` as ``dropout_rate``.
            **model_kwargs: Additional keyword arguments forwarded to :class:`JaxSCGENVAE`.
        """
        super().__init__(adata)

        self.module = JaxSCGENVAE(
            n_input=self.summary_stats.n_vars,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"Scgen Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}"
        )
        # Records the constructor arguments from ``locals()``; keep this as the last statement.
        self.init_params_ = self._get_init_params(locals())

    def predict(
        self,
        ctrl_key: str | None = None,
        stim_key: str | None = None,
        adata_to_predict: AnnData | None = None,
        celltype_to_predict: str | None = None,
        restrict_arithmetic_to: str | dict[str, Sequence[str]] = "all",
    ) -> tuple[AnnData, Any]:
        """Predicts the cell type provided by the user in stimulated condition.

        The latent shift ``delta`` is the difference between the mean latent vector of stimulated
        and control cells (cell-type balanced via ``balancer`` and randomly subsampled to equal
        size). ``delta`` is added to the latent representation of the cells to predict and the
        result is decoded back to gene space.

        Args:
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            adata_to_predict: Adata for unperturbed cells you want to be predicted.
            celltype_to_predict: The cell type you want to be predicted.
            restrict_arithmetic_to: Either ``"all"`` to compute ``delta`` from all cells, or a dictionary
                ``{obs_key: [value, ...]}`` restricting the computation to cells whose ``obs[obs_key]``
                is one of the values. Only the first key of the dictionary is used.

        Returns:
            A tuple ``(predicted_adata, delta)``: an AnnData with the decoded predicted expression
            (``obs``, ``var`` and ``obsm`` copied from the cells that were predicted) and the latent
            difference vector between stimulated and control cells.

        Raises:
            ValueError: If both or neither of ``celltype_to_predict`` and ``adata_to_predict`` are given,
                or if there are no control or no stimulated cells to compute ``delta`` from.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        """
        # Validate the mutually exclusive inputs before doing any expensive work.
        if celltype_to_predict is not None and adata_to_predict is not None:
            raise ValueError("Please provide either a cell type or adata not both!")
        if celltype_to_predict is None and adata_to_predict is None:
            raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

        # use keys registered from `setup_anndata()`
        cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        if restrict_arithmetic_to == "all":
            source = self.adata
            balance = True
        else:
            key = next(iter(restrict_arithmetic_to))
            values = restrict_arithmetic_to[key]
            source = self.adata[self.adata.obs[key].isin(values)]
            # Balancing across cell types only makes sense when several values are kept.
            balance = len(values) > 1
        ctrl_x = source[source.obs[condition_key] == ctrl_key, :]
        stim_x = source[source.obs[condition_key] == stim_key, :]
        if balance:
            ctrl_x = balancer(ctrl_x, cell_type_key)
            stim_x = balancer(stim_x, cell_type_key)

        if celltype_to_predict is not None:
            ctrl_pred = extractor(
                self.adata,
                celltype_to_predict,
                condition_key,
                cell_type_key,
                ctrl_key,
                stim_key,
            )[1]
        else:
            ctrl_pred = adata_to_predict

        # Subsample both conditions to the same size so neither dominates the mean difference.
        n_pairs = min(ctrl_x.shape[0], stim_x.shape[0])
        if n_pairs == 0:
            raise ValueError(
                f"No cells found for ctrl_key={ctrl_key!r} and/or stim_key={stim_key!r} "
                f"in obs[{condition_key!r}]; cannot compute the latent shift."
            )
        rng = np.random.default_rng()
        cd_ind = rng.choice(range(ctrl_x.shape[0]), size=n_pairs, replace=False)
        stim_ind = rng.choice(range(stim_x.shape[0]), size=n_pairs, replace=False)
        ctrl_adata = ctrl_x[cd_ind, :]
        stim_adata = stim_x[stim_ind, :]

        latent_ctrl = self._avg_vector(ctrl_adata)
        latent_stim = self._avg_vector(stim_adata)
        delta = latent_stim - latent_ctrl

        latent_cd = self.get_latent_representation(ctrl_pred)
        stim_pred = delta + latent_cd
        predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

        predicted_adata = AnnData(
            X=np.array(predicted_cells),
            obs=ctrl_pred.obs.copy(),
            var=ctrl_pred.var.copy(),
            obsm=ctrl_pred.obsm.copy(),
        )
        return predicted_adata, delta

    def _avg_vector(self, adata):
        """Return the mean latent vector (over cells) of ``adata``."""
        return np.mean(self.get_latent_representation(adata), axis=0)

    def get_decoded_expression(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> Array:
        """Get decoded expression.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Decoded expression (the generative ``px`` output) for each cell, concatenated over minibatches.

        Raises:
            RuntimeError: If the model has not been trained.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
        if self.is_trained_ is False:
            raise RuntimeError("Please train the model first.")

        adata = self._validate_anndata(adata)
        # The Jax module consumes numpy minibatches (as in `get_latent_representation`).
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)
        decoded = []
        for array_dict in scdl:
            # `as_bound()` is called per batch on purpose, as in the original implementation.
            _, generative_outputs = self.module.as_bound()(array_dict, compute_loss=False)
            decoded.append(generative_outputs["px"])

        return jnp.concatenate(decoded)

    @staticmethod
    def _shift_latent_to_largest_batch(
        latent: np.ndarray,
        cell_types: np.ndarray,
        batches: np.ndarray,
    ) -> np.ndarray:
        """Align, per cell type, every batch to the largest batch of that cell type in latent space.

        For each cell type present in at least two batches, the cells of each batch are shifted by
        ``mean(largest batch) - mean(this batch)``. Ties for the largest batch go to the first batch
        in sorted order. Cell types that occur in a single batch are left unchanged.

        Args:
            latent: Latent representation, one row per cell.
            cell_types: Cell type label per row of ``latent``.
            batches: Batch label per row of ``latent``.

        Returns:
            A new array with the shifted latent representation; ``latent`` is not modified.
        """
        shifted = np.array(latent, copy=True)
        for cell_type in pd.unique(cell_types):
            in_type = cell_types == cell_type
            batch_labels, batch_sizes = np.unique(batches[in_type], return_counts=True)
            if len(batch_labels) < 2:
                continue
            # np.unique sorts labels and argmax returns the first maximum -> deterministic ties.
            reference = batch_labels[np.argmax(batch_sizes)]
            reference_mean = np.average(latent[in_type & (batches == reference)], axis=0)
            for batch in batch_labels:
                rows = in_type & (batches == batch)
                shifted[rows] += reference_mean - np.average(latent[rows], axis=0)
        return shifted

    def batch_removal(self, adata: AnnData | None = None) -> AnnData:
        """Removes batch effects.

        Cells are embedded into the latent space. Within each cell type that occurs in at least two
        batches, every batch is shifted so that its mean matches the mean of the largest batch. The
        shifted latent representation is then decoded back to gene space.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.

        Returns:
            corrected: `~anndata.AnnData`
            AnnData of corrected gene expression in adata.X and corrected latent space in adata.obsm["latent"],
            with rows in the same order as the input. ``obsm["corrected_latent"]`` holds the latent
            representation of the corrected expression. If the input adata has a ``raw`` attribute, its
            ``X`` and ``var`` are copied into ``corrected.raw``.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
        adata = self._validate_anndata(adata)
        latent_all = np.asarray(self.get_latent_representation(adata))
        # use keys registered from `setup_anndata()`
        cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        # Working on positional arrays keeps every row aligned with `adata`, so no reordering
        # (and no misalignment of `obsm`) is needed afterwards.
        corrected_latent = self._shift_latent_to_largest_batch(
            latent_all,
            adata.obs[cell_label_key].to_numpy(),
            adata.obs[batch_key].to_numpy(),
        )

        corrected = AnnData(
            np.array(self.module.as_bound().generative(corrected_latent)["px"]),
            obs=adata.obs.copy(),
        )
        corrected.var_names = adata.var_names.tolist()
        if adata.raw is not None:
            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        corrected.obsm["latent"] = corrected_latent
        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)
        return corrected

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        labels_key: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        scGen expects the expression data to come from `adata.X`

        %(param_batch_key)s
        %(param_labels_key)s

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def to_device(self, device):
        """Intentionally a no-op: device placement is not handled by this model."""
        pass

    @property
    def device(self):
        """Device reported by the underlying Jax module."""
        return self.module.device

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        n_samples: int = 1,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Return the latent representation for each cell.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            give_mean: If `True`, return the mean of the latent distribution (``qz.mean``); otherwise
                return the sampled ``z``.
            n_samples: Number of samples passed to the module's inference function. When
                ``give_mean`` is `False` and ``n_samples > 1``, minibatches are concatenated along axis 1
                (presumably samples lie on axis 0 — confirm against the module).
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Low-dimensional representation for each cell

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
        self._check_if_trained(warn=False)

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)

        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

        latent = []
        for array_dict in scdl:
            out = jit_inference_fn(self.module.rngs, array_dict)
            if give_mean:
                z = out["qz"].mean
            else:
                z = out["z"]
            latent.append(z)
        concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
        latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore

        return self.module.as_numpy_array(latent)

    @staticmethod
    def _column_variance(X) -> np.ndarray:
        """Return the per-column population variance (ddof=0) of a dense or scipy-sparse matrix.

        Sparse input is handled via ``E[x^2] - E[x]^2`` in float64 (clipped at 0) so that it is
        not densified; dense input uses ``np.var``.
        """
        from scipy import sparse

        if sparse.issparse(X):
            mean = np.asarray(X.mean(axis=0), dtype=np.float64).ravel()
            mean_sq = np.asarray(X.multiply(X).mean(axis=0), dtype=np.float64).ravel()
            return np.maximum(mean_sq - mean**2, 0.0)
        return np.asarray(np.var(X, axis=0)).ravel()

    def plot_reg_mean_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        show: bool = False,
        top_100_genes: list[str] | None = None,
        verbose: bool = False,
        legend: bool = True,
        title: str | None = None,
        x_coeff: float = 0.30,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots mean matching for a set of specified genes.

        Computes the per-gene mean expression of the cells whose ``obs[condition_key]`` equals
        ``axis_keys["x"]`` and ``axis_keys["y"]``, draws a regression plot of the two and annotates it
        with R^2 (squared correlation coefficient from ``scipy.stats.linregress``).

        Args:
            adata: AnnData object whose ``obs[condition_key]`` contains the conditions named in ``axis_keys``.
            condition_key: The key for the condition in ``adata.obs``.
            axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: Path passed to ``plt.savefig``; if falsy, the plot is not saved.
            gene_list: list of gene names to be highlighted and labelled in the plot.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            verbose: If `True`, log the computed R^2 values.
            legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
            x_coeff: Offset to print the R^2 value in x-direction.
            y_coeff: Offset to print the R^2 value in y-direction.
            fontsize: Fontsize used for text in the plot.
            **kwargs: ``range=(start, stop, step)`` sets the ticks of both axes; ``textsize`` sets the font size
                of the R^2 annotations (defaults to ``fontsize``).

        Returns:
            R^2 over all genes, or ``(R^2 all genes, R^2 top_100_genes)`` if ``top_100_genes`` is given.

        Raises:
            ValueError: If a gene in ``gene_list`` is not in ``adata.var_names``.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> scg = pt.tl.Scgen(data)
            >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
            >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
            >>> r2_value = scg.plot_reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \
                labels={"x": "predicted", "y": "ground truth"}, save=False, show=True)

        Preview:
            .. image:: /_static/docstring_previews/scgen_reg_mean.png
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
            y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
            r_value_diff = stats.linregress(x_diff, y_diff).rvalue
            if verbose:
                logger.info(f"top_100 DEGs mean: {r_value_diff**2}")
        x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.mean(stim.X, axis=0)).ravel()
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            logger.info(f"All genes mean: {r_value**2}")
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)
        if gene_list is not None:
            # Build the name list once instead of once per highlighted gene.
            var_names = adata.var_names.tolist()
            texts = []
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                texts.append(plt.text(x_bar, y_bar, gene, fontsize=11, color="black"))
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
            adjust_text(
                texts,
                x=x,
                y=y,
                arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
                force_static=(0.0, 0.0),
            )
        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        if title is None:
            plt.title("", fontsize=fontsize)
        else:
            plt.title(title, fontsize=fontsize)
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )
        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()
        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        else:
            return r_value**2

    def plot_reg_var_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        top_100_genes: list[str] | None = None,
        show: bool = False,
        legend: bool = True,
        title: str | None = None,
        verbose: bool = False,
        x_coeff: float = 0.3,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots variance matching for a set of specified genes.

        Computes the per-gene expression variance of the cells whose ``obs[condition_key]`` equals
        ``axis_keys["x"]`` and ``axis_keys["y"]``, draws a regression plot of the two and annotates it
        with R^2 (squared correlation coefficient from ``scipy.stats.linregress``). Works with dense and
        scipy-sparse ``X``.

        Args:
            adata: AnnData object whose ``obs[condition_key]`` contains the conditions named in ``axis_keys``.
            condition_key: Key of the condition in ``adata.obs``.
            axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`. An optional ``"y1"`` entry adds a grey
                scatter of the variances of that condition against the x condition.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: Path passed to ``plt.savefig``; if falsy, the plot is not saved.
            gene_list: list of gene names to be highlighted and labelled in the plot.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            show: if `True`: will show to the plot after saving it.
            legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
            verbose: If `True`, log the computed R^2 values.
            x_coeff: Offset to print the R^2 value in x-direction.
            y_coeff: Offset to print the R^2 value in y-direction.
            fontsize: Fontsize used for text in the plot.
            **kwargs: ``range=(start, stop, step)`` sets the ticks of both axes; ``textsize`` sets the font size
                of the R^2 annotations (defaults to ``fontsize``).

        Returns:
            R^2 over all genes, or ``(R^2 all genes, R^2 top_100_genes)`` if ``top_100_genes`` is given.

        Raises:
            ValueError: If a gene in ``gene_list`` is not in ``adata.var_names``.
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = self._column_variance(ctrl_diff.X)
            y_diff = self._column_variance(stim_diff.X)
            r_value_diff = stats.linregress(x_diff, y_diff).rvalue
            if verbose:
                logger.info(f"Top 100 DEGs var: {r_value_diff**2}")
        x = self._column_variance(ctrl.X)
        y = self._column_variance(stim.X)
        r_value = stats.linregress(x, y).rvalue
        if verbose:
            logger.info(f"All genes var: {r_value**2}")
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)
        has_y1 = "y1" in axis_keys
        if has_y1:
            real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
            y1 = self._column_variance(real_stim.X)
            plt.scatter(
                x,
                y1,
                marker="*",
                c="grey",
                alpha=0.5,
                label=f"{axis_keys['x']}-{axis_keys['y1']}",
            )
        if gene_list is not None:
            # Build the name list once instead of once per highlighted gene.
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                plt.text(x_bar, y_bar, gene, fontsize=11, color="black")
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
                if has_y1:
                    plt.text(x_bar, y1[j], "*", color="black", alpha=0.5)
        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        if title is None:
            plt.title("", fontsize=12)
        else:
            plt.title(title, fontsize=12)
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()
        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        else:
            return r_value**2

    def plot_binary_classifier(
        self,
        scgen: Scgen,
        adata: AnnData | None,
        delta: np.ndarray,
        ctrl_key: str,
        stim_key: str,
        show: bool = False,
        save: str | bool | None = None,
        fontsize: float = 14,
    ) -> plt.Axes | None:
        """Plots the dot product between delta and latent representation of a linear classifier.

        Builds a linear classifier based on the dot product between
        the difference vector and the latent representation of each
        cell and plots the dot product results between delta and latent representation.
        Closes all open matplotlib figures before plotting.

        Args:
            scgen: ScGen object that was trained.
            adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
                AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.
            delta: Difference between stimulated and control cells in latent space
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            show: If `True`, show the plot.
            save: Path passed to ``plt.savefig``; if falsy, the plot is not saved.
            fontsize: Set the font size of the plot.

        Returns:
            The matplotlib Axes if neither ``show`` nor ``save`` is set, otherwise `None`.
        """
        plt.close("all")
        adata = scgen._validate_anndata(adata)
        condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
        cd = adata[adata.obs[condition_key] == ctrl_key, :]
        stim = adata[adata.obs[condition_key] == stim_key, :]
        # `get_latent_representation` validates its input as AnnData, so pass AnnData subsets
        # (not their `.X` matrices).
        all_latent_cd = np.asarray(scgen.get_latent_representation(cd.copy()))
        all_latent_stim = np.asarray(scgen.get_latent_representation(stim.copy()))
        delta_vec = np.asarray(delta).ravel()
        dot_cd = np.asarray(all_latent_cd @ delta_vec, dtype=np.float64)
        dot_sal = np.asarray(all_latent_stim @ delta_vec, dtype=np.float64)
        plt.hist(
            dot_cd,
            label=ctrl_key,
            bins=50,
        )
        plt.hist(dot_sal, label=stim_key, bins=50)
        plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
        plt.title(" ", fontsize=fontsize)
        plt.xlabel(" ", fontsize=fontsize)
        plt.ylabel(" ", fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        ax = plt.gca()
        ax.grid(False)

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if not (show or save):
            return ax
        return None


# Compatibility alias: ``SCGEN`` refers to the same class as ``Scgen``.
SCGEN = Scgen