pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +129 -145
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +48 -40
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -45
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
@@ -5,10 +5,10 @@ import warnings
5
5
  import anndata
6
6
  import numpy as np
7
7
  import pandas as pd
8
- import pytorch_lightning as pl
9
8
  import scipy
10
9
  import torch
11
10
  from anndata import AnnData
11
+ from pytorch_lightning import LightningModule, Trainer
12
12
  from pytorch_lightning.callbacks import EarlyStopping
13
13
  from sklearn.linear_model import LogisticRegression
14
14
  from sklearn.model_selection import train_test_split
@@ -35,9 +35,7 @@ class LRClassifierSpace(PerturbationSpace):
35
35
  test_split_size: float = 0.2,
36
36
  max_iter: int = 1000,
37
37
  ):
38
- """
39
- Fits a logistic regression model to the data and takes the coefficients of the logistic regression
40
- model as perturbation embedding.
38
+ """Fits a logistic regression model to the data and takes the coefficients of the logistic regression model as perturbation embedding.
41
39
 
42
40
  Args:
43
41
  adata: AnnData object of size cells x genes
@@ -60,7 +58,7 @@ class LRClassifierSpace(PerturbationSpace):
60
58
  if layer_key is not None and layer_key not in adata.obs.columns:
61
59
  raise ValueError(f"Layer key {layer_key} not found in adata.")
62
60
 
63
- if embedding_key is not None and embedding_key not in adata.obsm.keys():
61
+ if embedding_key is not None and embedding_key not in adata.obsm:
64
62
  raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")
65
63
 
66
64
  if layer_key is not None and embedding_key is not None:
@@ -207,7 +205,7 @@ class MLPClassifierSpace(PerturbationSpace):
207
205
  adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
208
206
 
209
207
  # Split the data in train, test and validation
210
- X = list(range(0, adata.n_obs))
208
+ X = list(range(adata.n_obs))
211
209
  y = adata.obs[target_col]
212
210
 
213
211
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
@@ -248,7 +246,7 @@ class MLPClassifierSpace(PerturbationSpace):
248
246
  # Save adata observations for embedding annotations in get_embeddings
249
247
  self.adata_obs = adata.obs.reset_index(drop=True)
250
248
 
251
- self.trainer = pl.Trainer(
249
+ self.trainer = Trainer(
252
250
  min_epochs=1,
253
251
  max_epochs=max_epochs,
254
252
  check_val_every_n_epoch=val_epochs_check,
@@ -273,7 +271,7 @@ class MLPClassifierSpace(PerturbationSpace):
273
271
  if dataset_count == 0:
274
272
  pert_adata = batch_adata
275
273
  else:
276
- pert_adata = anndata.concat([pert_adata, batch_adata])
274
+ pert_adata = batch_adata if dataset_count == 0 else anndata.concat([pert_adata, batch_adata])
277
275
 
278
276
  # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
279
277
  # and we can just add the annotations from the original AnnData object
@@ -308,9 +306,7 @@ class MLPClassifierSpace(PerturbationSpace):
308
306
 
309
307
 
310
308
  class MLP(torch.nn.Module):
311
- """
312
- A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
313
- """
309
+ """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
314
310
 
315
311
  def __init__(
316
312
  self,
@@ -320,7 +316,8 @@ class MLP(torch.nn.Module):
320
316
  layer_norm: bool = False,
321
317
  last_layer_act: str = "linear",
322
318
  ) -> None:
323
- """
319
+ """Multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
320
+
324
321
  Args:
325
322
  sizes: size of layers.
326
323
  dropout: Dropout probability.
@@ -375,8 +372,8 @@ def init_weights(m):
375
372
 
376
373
 
377
374
  class PLDataset(Dataset):
378
- """
379
- Dataset for perturbation classification.
375
+ """Dataset for perturbation classification.
376
+
380
377
  Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
381
378
  """
382
379
 
@@ -387,14 +384,14 @@ class PLDataset(Dataset):
387
384
  label_col: str = "perturbations",
388
385
  layer_key: str = None,
389
386
  ):
390
- """
387
+ """PyTorch lightning Dataset for perturbation classification.
388
+
391
389
  Args:
392
390
  adata: AnnData object with observations and labels.
393
391
  target_col: key with the perturbation labels numerically encoded.
394
392
  label_col: key with the perturbation labels.
395
- layer_key: key of the layer to be used as data, otherwise .X
393
+ layer_key: key of the layer to be used as data, otherwise .X.
396
394
  """
397
-
398
395
  if layer_key:
399
396
  self.data = adata.layers[layer_key]
400
397
  else:
@@ -407,7 +404,7 @@ class PLDataset(Dataset):
407
404
  return self.data.shape[0]
408
405
 
409
406
  def __getitem__(self, idx):
410
- """Returns a sample and corresponding perturbations applied (labels)"""
407
+ """Returns a sample and corresponding perturbations applied (labels)."""
411
408
  sample = self.data[idx].toarray().squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
412
409
  num_label = self.labels.iloc[idx]
413
410
  str_label = self.pert_labels.iloc[idx]
@@ -415,7 +412,7 @@ class PLDataset(Dataset):
415
412
  return sample, num_label, str_label
416
413
 
417
414
 
418
- class PerturbationClassifier(pl.LightningModule):
415
+ class PerturbationClassifier(LightningModule):
419
416
  def __init__(
420
417
  self,
421
418
  model: torch.nn.Module,
@@ -428,7 +425,8 @@ class PerturbationClassifier(pl.LightningModule):
428
425
  lr=1e-4,
429
426
  seed=42,
430
427
  ):
431
- """
428
+ """Perturbation Classifier.
429
+
432
430
  Args:
433
431
  model: model to be trained
434
432
  batch_size: batch size
@@ -438,7 +436,7 @@ class PerturbationClassifier(pl.LightningModule):
438
436
  layer_norm: whether to apply layer norm
439
437
  last_layer_act: activation function of last layer
440
438
  lr: learning rate
441
- seed: random seed
439
+ seed: random seed.
442
440
  """
443
441
  super().__init__()
444
442
  self.batch_size = batch_size
@@ -457,16 +455,37 @@ class PerturbationClassifier(pl.LightningModule):
457
455
  last_layer_act=self.hparams.last_layer_act,
458
456
  )
459
457
 
460
- def forward(self, x):
458
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
459
+ """Forward pass through the network.
460
+
461
+ Args:
462
+ x: Input tensor
463
+
464
+ Returns:
465
+ Network output tensor
466
+ """
461
467
  x = self.net(x)
462
468
  return x
463
469
 
464
- def configure_optimizers(self):
465
- optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
470
+ def configure_optimizers(self) -> optim.Adam:
471
+ """Configure optimizer for the model.
466
472
 
473
+ Returns:
474
+ Adam optimizer with weight decay
475
+ """
476
+ optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
467
477
  return optimizer
468
478
 
469
- def training_step(self, batch, batch_idx):
479
+ def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
480
+ """Perform a training step.
481
+
482
+ Args:
483
+ batch: Tuple of (input, target, metadata)
484
+ batch_idx: Index of the current batch
485
+
486
+ Returns:
487
+ Loss value
488
+ """
470
489
  x, y, _ = batch
471
490
  x = x.to(torch.float32)
472
491
 
@@ -480,7 +499,16 @@ class PerturbationClassifier(pl.LightningModule):
480
499
 
481
500
  return loss
482
501
 
483
- def validation_step(self, batch, batch_idx):
502
+ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
503
+ """Perform a validation step.
504
+
505
+ Args:
506
+ batch: Tuple of (input, target, metadata)
507
+ batch_idx: Index of the current batch
508
+
509
+ Returns:
510
+ Loss value
511
+ """
484
512
  x, y, _ = batch
485
513
  x = x.to(torch.float32)
486
514
 
@@ -494,7 +522,16 @@ class PerturbationClassifier(pl.LightningModule):
494
522
 
495
523
  return loss
496
524
 
497
- def test_step(self, batch, batch_idx):
525
+ def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
526
+ """Perform a test step.
527
+
528
+ Args:
529
+ batch: Tuple of (input, target, metadata)
530
+ batch_idx: Index of the current batch
531
+
532
+ Returns:
533
+ Loss value
534
+ """
498
535
  x, y, _ = batch
499
536
  x = x.to(torch.float32)
500
537
 
@@ -508,15 +545,29 @@ class PerturbationClassifier(pl.LightningModule):
508
545
 
509
546
  return loss
510
547
 
511
- def embedding(self, x):
512
- """
513
- Inputs:
514
- x: Input features of shape [Batch, SeqLen, 1]
548
+ def embedding(self, x: torch.Tensor) -> torch.Tensor:
549
+ """Extract embeddings from input features.
550
+
551
+ Args:
552
+ x: Input tensor of shape [Batch, SeqLen, 1]
553
+
554
+ Returns:
555
+ Embedded representation of the input
515
556
  """
516
557
  x = self.net.embedding(x)
517
558
  return x
518
559
 
519
- def get_embeddings(self, batch):
560
+ def get_embeddings(
561
+ self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
562
+ ) -> tuple[torch.Tensor, torch.Tensor]:
563
+ """Extract embeddings from a batch.
564
+
565
+ Args:
566
+ batch: Tuple of (input, target, metadata)
567
+
568
+ Returns:
569
+ Tuple of (embeddings, metadata)
570
+ """
520
571
  x, _, y = batch
521
572
  x = x.to(torch.float32)
522
573
 
@@ -70,7 +70,7 @@ class PerturbationSpace:
70
70
  if embedding_key is not None and embedding_key not in adata.obsm_keys():
71
71
  raise ValueError(f"Embedding key {embedding_key} not found in obsm keys of the anndata.")
72
72
 
73
- if layer_key is not None and layer_key not in adata.layers.keys():
73
+ if layer_key is not None and layer_key not in adata.layers:
74
74
  raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
75
75
 
76
76
  if copy:
@@ -123,7 +123,7 @@ class PerturbationSpace:
123
123
  if all_data:
124
124
  layers_keys = list(adata.layers.keys())
125
125
  for local_layer_key in layers_keys:
126
- if local_layer_key != layer_key and local_layer_key != new_layer_key:
126
+ if local_layer_key not in (layer_key, new_layer_key):
127
127
  adata.layers[local_layer_key + "_control_diff"] = np.zeros((adata.n_obs, adata.n_vars))
128
128
  for mask in group_masks:
129
129
  adata.layers[local_layer_key + "_control_diff"][mask, :] = adata.layers[local_layer_key][
@@ -132,7 +132,7 @@ class PerturbationSpace:
132
132
 
133
133
  embedding_keys = list(adata.obsm_keys())
134
134
  for local_embedding_key in embedding_keys:
135
- if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
135
+ if local_embedding_key not in (embedding_key, new_embedding_key):
136
136
  adata.obsm[local_embedding_key + "_control_diff"] = np.zeros(adata.obsm[local_embedding_key].shape)
137
137
  for mask in group_masks:
138
138
  adata.obsm[local_embedding_key + "_control_diff"][mask, :] = adata.obsm[local_embedding_key][
@@ -193,7 +193,7 @@ class PerturbationSpace:
193
193
 
194
194
  data: dict[str, np.array] = {}
195
195
 
196
- for local_layer_key in adata.layers.keys():
196
+ for local_layer_key in adata.layers:
197
197
  data["layers"] = {}
198
198
  control_local = adata[reference_key].layers[local_layer_key].copy()
199
199
  for perturbation in perturbations:
@@ -231,14 +231,14 @@ class PerturbationSpace:
231
231
  new_obs.loc[new_pert_name[:-1]] = new_pert_obs
232
232
  new_perturbation.obs = new_obs
233
233
 
234
- if "layers" in data.keys():
234
+ if "layers" in data:
235
235
  for key in data["layers"]:
236
236
  key_name = key
237
237
  if key.endswith("_control_diff"):
238
238
  key_name = key.removesuffix("_control_diff")
239
239
  new_perturbation.layers[key_name] = data["layers"][key]
240
240
 
241
- if "embeddings" in data.keys():
241
+ if "embeddings" in data:
242
242
  key_name = key
243
243
  for key in data["embeddings"]:
244
244
  if key.endswith("_control_diff"):
@@ -260,7 +260,7 @@ class PerturbationSpace:
260
260
  ensure_consistency: bool = False,
261
261
  target_col: str = "perturbation",
262
262
  ) -> tuple[AnnData, AnnData] | AnnData:
263
- """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
263
+ """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality.
264
264
 
265
265
  Args:
266
266
  adata: Anndata object of size n_perts x dim.
@@ -302,7 +302,7 @@ class PerturbationSpace:
302
302
 
303
303
  data: dict[str, np.array] = {}
304
304
 
305
- for local_layer_key in adata.layers.keys():
305
+ for local_layer_key in adata.layers:
306
306
  data["layers"] = {}
307
307
  control_local = adata[reference_key].layers[local_layer_key].copy()
308
308
  for perturbation in perturbations:
@@ -340,14 +340,14 @@ class PerturbationSpace:
340
340
  new_obs.loc[new_pert_name[:-1]] = new_pert_obs
341
341
  new_perturbation.obs = new_obs
342
342
 
343
- if "layers" in data.keys():
343
+ if "layers" in data:
344
344
  for key in data["layers"]:
345
345
  key_name = key
346
346
  if key.endswith("_control_diff"):
347
347
  key_name = key.removesuffix("_control_diff")
348
348
  new_perturbation.layers[key_name] = data["layers"][key]
349
349
 
350
- if "embeddings" in data.keys():
350
+ if "embeddings" in data:
351
351
  key_name = key
352
352
  for key in data["embeddings"]:
353
353
  if key.endswith("_control_diff"):
@@ -2,10 +2,11 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
5
- import decoupler as dc
6
5
  import matplotlib.pyplot as plt
7
6
  import numpy as np
8
7
  from anndata import AnnData
8
+ from decoupler import get_pseudobulk as dc_get_pseudobulk
9
+ from decoupler import plot_psbulk_samples as dc_plot_psbulk_samples
9
10
  from sklearn.cluster import DBSCAN, KMeans
10
11
 
11
12
  from pertpy._doc import _doc_params, doc_common_plot_args
@@ -53,7 +54,6 @@ class CentroidSpace(PerturbationSpace):
53
54
  >>> cs = pt.tl.CentroidSpace()
54
55
  >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target")
55
56
  """
56
-
57
57
  X = None
58
58
  if layer_key is not None and embedding_key is not None:
59
59
  raise ValueError("Please, select just either layer or embedding for computation.")
@@ -65,7 +65,7 @@ class CentroidSpace(PerturbationSpace):
65
65
  X = np.empty((len(adata.obs[target_col].unique()), adata.obsm[embedding_key].shape[1]))
66
66
 
67
67
  if layer_key is not None:
68
- if layer_key not in adata.layers.keys():
68
+ if layer_key not in adata.layers:
69
69
  raise ValueError(f"Layer {layer_key!r} does not exist in the .layers attribute.")
70
70
  else:
71
71
  X = np.empty((len(adata.obs[target_col].unique()), adata.layers[layer_key].shape[1]))
@@ -79,8 +79,7 @@ class CentroidSpace(PerturbationSpace):
79
79
  X = np.empty((len(adata.obs[target_col].unique()), adata.obsm[embedding_key].shape[1]))
80
80
 
81
81
  index = []
82
- pert_index = 0
83
- for group_name, group_data in grouped:
82
+ for pert_index, (group_name, group_data) in enumerate(grouped):
84
83
  indices = group_data.index
85
84
  if layer_key is not None:
86
85
  points = adata[indices].layers[layer_key]
@@ -94,7 +93,6 @@ class CentroidSpace(PerturbationSpace):
94
93
  points, key=lambda point: np.linalg.norm(point - centroid)
95
94
  ) # Find the point in the array closest to the centroid
96
95
  X[pert_index, :] = closest_point
97
- pert_index += 1
98
96
 
99
97
  ps_adata = AnnData(X=X)
100
98
  ps_adata.obs_names = index
@@ -153,7 +151,7 @@ class PseudobulkSpace(PerturbationSpace):
153
151
  if layer_key is not None and embedding_key is not None:
154
152
  raise ValueError("Please, select just either layer or embedding for computation.")
155
153
 
156
- if layer_key is not None and layer_key not in adata.layers.keys():
154
+ if layer_key is not None and layer_key not in adata.layers:
157
155
  raise ValueError(f"Layer {layer_key!r} does not exist in the .layers attribute.")
158
156
 
159
157
  if target_col not in adata.obs:
@@ -169,14 +167,14 @@ class PseudobulkSpace(PerturbationSpace):
169
167
  adata = adata_emb
170
168
 
171
169
  adata.obs[target_col] = adata.obs[target_col].astype("category")
172
- ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore
170
+ ps_adata = dc_get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore
173
171
 
174
172
  ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
175
173
 
176
174
  return ps_adata
177
175
 
178
176
  @_doc_params(common_plot_args=doc_common_plot_args)
179
- def plot_psbulk_samples(
177
+ def plot_psbulk_samples( # pragma: no cover # noqa: D417
180
178
  self,
181
179
  adata: AnnData,
182
180
  groupby: str,
@@ -209,7 +207,7 @@ class PseudobulkSpace(PerturbationSpace):
209
207
  Preview:
210
208
  .. image:: /_static/docstring_previews/pseudobulk_samples.png
211
209
  """
212
- fig = dc.plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
210
+ fig = dc_plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
213
211
 
214
212
  if return_fig:
215
213
  return fig
@@ -244,7 +242,7 @@ class KMeansSpace(ClusteringSpace):
244
242
  Returns:
245
243
  If return_object is True, the adata and the clustering object is returned.
246
244
  Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
247
- that stores the cluster labels.
245
+ that stores the cluster labels.
248
246
 
249
247
  Examples:
250
248
  >>> import pertpy as pt
@@ -265,7 +263,7 @@ class KMeansSpace(ClusteringSpace):
265
263
  self.X = adata.obsm[embedding_key]
266
264
 
267
265
  elif layer_key is not None:
268
- if layer_key not in adata.layers.keys():
266
+ if layer_key not in adata.layers:
269
267
  raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
270
268
  else:
271
269
  self.X = adata.layers[layer_key]
@@ -284,7 +282,7 @@ class KMeansSpace(ClusteringSpace):
284
282
 
285
283
 
286
284
  class DBSCANSpace(ClusteringSpace):
287
- """Cluster the given data using DBSCAN"""
285
+ """Cluster the given data using DBSCAN."""
288
286
 
289
287
  def compute( # type: ignore
290
288
  self,
@@ -328,7 +326,7 @@ class DBSCANSpace(ClusteringSpace):
328
326
  self.X = adata.obsm[embedding_key]
329
327
 
330
328
  elif layer_key is not None:
331
- if layer_key not in adata.layers.keys():
329
+ if layer_key not in adata.layers:
332
330
  raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
333
331
  else:
334
332
  self.X = adata.layers[layer_key]
@@ -77,7 +77,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
77
77
  restrict_arithmetic_to: Dictionary of celltypes you want to be observed for prediction.
78
78
 
79
79
  Returns:
80
- `np nd-array` of predicted cells in primary space.
80
+ :class:`numpy.ndarray` of predicted cells in primary space.
81
81
  delta: float
82
82
  Difference between stimulated and control cells in latent space
83
83
 
@@ -198,7 +198,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
198
198
  corresponding to batch and cell type metadata, respectively.
199
199
 
200
200
  Returns:
201
- corrected: `~anndata.AnnData`
201
+ A corrected `~anndata.AnnData` object.
202
202
  AnnData of corrected gene expression in adata.X and corrected latent space in adata.obsm["latent"].
203
203
  A reference to the original AnnData is in `corrected.raw` if the input adata had no `raw` attribute.
204
204
 
@@ -343,6 +343,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
343
343
  AnnData object used to initialize the model.
344
344
  indices: Indices of cells in adata to use. If `None`, all cells are used.
345
345
  batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
346
+ give_mean: Whether to return the mean
347
+ n_samples: The number of samples to use.
346
348
 
347
349
  Returns:
348
350
  Low-dimensional representation for each cell
@@ -365,17 +367,14 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
365
367
  latent = []
366
368
  for array_dict in scdl:
367
369
  out = jit_inference_fn(self.module.rngs, array_dict)
368
- if give_mean:
369
- z = out["qz"].mean
370
- else:
371
- z = out["z"]
370
+ z = out["qz"].mean if give_mean else out["z"]
372
371
  latent.append(z)
373
372
  concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
374
373
  latent = jnp.concatenate(latent, axis=concat_axis) # type: ignore
375
374
 
376
375
  return self.module.as_numpy_array(latent)
377
376
 
378
- def plot_reg_mean_plot(
377
+ def plot_reg_mean_plot( # pragma: no cover # noqa: D417
379
378
  self,
380
379
  adata,
381
380
  condition_key: str,
@@ -495,14 +494,14 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
495
494
  ax.text(
496
495
  max(x) - max(x) * x_coeff,
497
496
  max(y) - y_coeff * max(y),
498
- r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
497
+ r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value**2:.2f}",
499
498
  fontsize=kwargs.get("textsize", fontsize),
500
499
  )
501
500
  if diff_genes is not None:
502
501
  ax.text(
503
502
  max(x) - max(x) * x_coeff,
504
503
  max(y) - (y_coeff + 0.15) * max(y),
505
- r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
504
+ r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff**2:.2f}",
506
505
  fontsize=kwargs.get("textsize", fontsize),
507
506
  )
508
507
 
@@ -516,7 +515,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
516
515
  else:
517
516
  return r_value**2
518
517
 
519
- def plot_reg_var_plot(
518
+ def plot_reg_var_plot( # pragma: no cover # noqa: D417
520
519
  self,
521
520
  adata,
522
521
  condition_key: str,
@@ -576,7 +575,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
576
575
  m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
577
576
  if verbose:
578
577
  logger.info("Top 100 DEGs var: ", r_value_diff**2)
579
- if "y1" in axis_keys.keys():
578
+ if "y1" in axis_keys:
580
579
  real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
581
580
  x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
582
581
  y = np.asarray(np.var(stim.X, axis=0)).ravel()
@@ -594,7 +593,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
594
593
  # plt.plot(x, m * x + b, "-", color="green")
595
594
  ax.set_xlabel(labels["x"], fontsize=fontsize)
596
595
  ax.set_ylabel(labels["y"], fontsize=fontsize)
597
- if "y1" in axis_keys.keys():
596
+ if "y1" in axis_keys:
598
597
  y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
599
598
  _ = plt.scatter(
600
599
  x,
@@ -611,7 +610,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
611
610
  y_bar = y[j]
612
611
  plt.text(x_bar, y_bar, i, fontsize=11, color="black")
613
612
  plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
614
- if "y1" in axis_keys.keys():
613
+ if "y1" in axis_keys:
615
614
  y1_bar = y1[j]
616
615
  plt.text(x_bar, y1_bar, "*", color="black", alpha=0.5)
617
616
  if legend:
@@ -623,14 +622,14 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
623
622
  ax.text(
624
623
  max(x) - max(x) * x_coeff,
625
624
  max(y) - y_coeff * max(y),
626
- r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
625
+ r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value**2:.2f}",
627
626
  fontsize=kwargs.get("textsize", fontsize),
628
627
  )
629
628
  if diff_genes is not None:
630
629
  ax.text(
631
630
  max(x) - max(x) * x_coeff,
632
631
  max(y) - (y_coeff + 0.15) * max(y),
633
- r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
632
+ r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff**2:.2f}",
634
633
  fontsize=kwargs.get("textsize", fontsize),
635
634
  )
636
635
 
@@ -645,7 +644,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
645
644
  return r_value**2
646
645
 
647
646
  @_doc_params(common_plot_args=doc_common_plot_args)
648
- def plot_binary_classifier(
647
+ def plot_binary_classifier( # pragma: no cover # noqa: D417
649
648
  self,
650
649
  scgen: Scgen,
651
650
  adata: AnnData | None,
@@ -665,7 +664,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
665
664
  Args:
666
665
  scgen: ScGen object that was trained.
667
666
  adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
668
- AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,
667
+ AnnData object used to initialize the model. Must have been set up with `batch_key` and `labels_key`,
669
668
  corresponding to batch and cell type metadata, respectively.
670
669
  delta: Difference between stimulated and control cells in latent space
671
670
  ctrl_key: Key for `control` part of the `data` found in `condition_key`.
@@ -24,8 +24,8 @@ class JaxSCGENVAE(JaxBaseModuleClass):
24
24
  training: bool = True
25
25
 
26
26
  def setup(self):
27
- use_batch_norm_encoder = self.use_batch_norm == "encoder" or self.use_batch_norm == "both"
28
- use_layer_norm_encoder = self.use_layer_norm == "encoder" or self.use_layer_norm == "both"
27
+ use_batch_norm_encoder = self.use_batch_norm in ("encoder", "both")
28
+ use_layer_norm_encoder = self.use_layer_norm in ("encoder", "both")
29
29
 
30
30
  self.encoder = FlaxEncoder(
31
31
  n_latent=self.n_latent,
@@ -32,7 +32,9 @@ def extractor(
32
32
 
33
33
  train_data = anndata.read("./data/train.h5ad")
34
34
  test_data = anndata.read("./data/test.h5ad")
35
- train_data_extracted_list = extractor(train_data, "CD4T", "conditions", "cell_type", "control", "stimulated")
35
+ train_data_extracted_list = extractor(
36
+ train_data, "CD4T", "conditions", "cell_type", "control", "stimulated"
37
+ )
36
38
  """
37
39
  cell_with_both_condition = data[data.obs[cell_type_key] == cell_type]
38
40
  condition_1 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == ctrl_key)]