pertpy-0.7.0-py3-none-any.whl → pertpy-0.9.1-py3-none-any.whl

Files changed (56)
  1. pertpy/__init__.py +2 -1
  2. pertpy/data/__init__.py +61 -0
  3. pertpy/data/_dataloader.py +27 -23
  4. pertpy/data/_datasets.py +58 -0
  5. pertpy/metadata/__init__.py +2 -0
  6. pertpy/metadata/_cell_line.py +39 -70
  7. pertpy/metadata/_compound.py +3 -4
  8. pertpy/metadata/_drug.py +2 -6
  9. pertpy/metadata/_look_up.py +38 -51
  10. pertpy/metadata/_metadata.py +7 -10
  11. pertpy/metadata/_moa.py +2 -6
  12. pertpy/plot/__init__.py +0 -5
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +6 -7
  15. pertpy/tools/__init__.py +67 -6
  16. pertpy/tools/_augur.py +14 -15
  17. pertpy/tools/_cinemaot.py +2 -2
  18. pertpy/tools/_coda/_base_coda.py +118 -142
  19. pertpy/tools/_coda/_sccoda.py +16 -15
  20. pertpy/tools/_coda/_tasccoda.py +21 -22
  21. pertpy/tools/_dialogue.py +18 -23
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +21 -16
  32. pertpy/tools/_distances/_distances.py +406 -70
  33. pertpy/tools/_enrichment.py +10 -15
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +77 -54
  36. pertpy/tools/_mixscape.py +15 -11
  37. pertpy/tools/_perturbation_space/_clustering.py +5 -2
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +21 -23
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
  41. pertpy/tools/_perturbation_space/_simple.py +3 -3
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +33 -28
  45. pertpy/tools/_scgen/_utils.py +2 -2
  46. {pertpy-0.7.0.dist-info → pertpy-0.9.1.dist-info}/METADATA +32 -14
  47. pertpy-0.9.1.dist-info/RECORD +57 -0
  48. {pertpy-0.7.0.dist-info → pertpy-0.9.1.dist-info}/WHEEL +1 -1
  49. pertpy/plot/_augur.py +0 -171
  50. pertpy/plot/_coda.py +0 -601
  51. pertpy/plot/_guide_rna.py +0 -64
  52. pertpy/plot/_milopy.py +0 -209
  53. pertpy/plot/_mixscape.py +0 -355
  54. pertpy/tools/_differential_gene_expression.py +0 -325
  55. pertpy-0.7.0.dist-info/RECORD +0 -53
  56. {pertpy-0.7.0.dist-info → pertpy-0.9.1.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_comparison.py

@@ -0,0 +1,112 @@
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pynndescent
+ from scipy.sparse import issparse
+ from scipy.sparse import vstack as sp_vstack
+ from sklearn.base import ClassifierMixin
+ from sklearn.linear_model import LogisticRegression
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ class PerturbationComparison:
+     """Comparison between real and simulated perturbations."""
+
+     def compare_classification(
+         self,
+         real: np.ndarray,
+         simulated: np.ndarray,
+         control: np.ndarray,
+         clf: ClassifierMixin | None = None,
+     ) -> float:
+         """Compare classification accuracy between real and simulated perturbations.
+
+         Trains a classifier on the real perturbation data + the control data and reports a normalized
+         classification accuracy on the simulated perturbation.
+
+         Args:
+             real: Real perturbed data.
+             simulated: Simulated perturbed data.
+             control: Control data.
+             clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
+         """
+         assert real.shape[1] == simulated.shape[1] == control.shape[1]
+         if clf is None:
+             clf = LogisticRegression()
+         n_x = real.shape[0]
+         data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control))
+         labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")])
+
+         clf.fit(data, labels)
+         norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x])
+         norm_score = min(1.0, norm_score)
+
+         return norm_score
+
+     def compare_knn(
+         self,
+         real: np.ndarray,
+         simulated: np.ndarray,
+         control: np.ndarray | None = None,
+         use_simulated_for_knn: bool = False,
+         n_neighbors: int = 20,
+         random_state: int = 0,
+         n_jobs: int = 1,
+     ) -> dict[str, float]:
+         """Calculate proportions of real perturbed and control data points for simulated data.
+
+         Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
+         data points for simulated data. If `control` is not provided, builds the knn graph from
+         real perturbed + simulated perturbed.
+
+         Args:
+             real: Real perturbed data.
+             simulated: Simulated perturbed data.
+             control: Control data.
+             use_simulated_for_knn: Include simulated perturbed data (`simulated`) in the knn graph. Only valid when
+                 `control` is provided.
+             n_neighbors: Number of neighbors to use in k-neighbor graph.
+             random_state: Random state used for k-neighbor graph construction.
+             n_jobs: Number of cores to use.
+
+         """
+         assert real.shape[1] == simulated.shape[1]
+         if control is not None:
+             assert real.shape[1] == control.shape[1]
+
+         n_y = simulated.shape[0]
+
+         if control is None:
+             index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real))
+         else:
+             datas = (simulated, real, control) if use_simulated_for_knn else (real, control)
+             index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas)
+
+         y_in_index = use_simulated_for_knn or control is None
+         c_in_index = control is not None
+         label_groups = ["comp"]
+         labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp")
+         if y_in_index:
+             labels[:n_y] = "siml"
+             label_groups.append("siml")
+         if c_in_index:
+             labels[-control.shape[0] :] = "ctrl"
+             label_groups.append("ctrl")
+
+         index = pynndescent.NNDescent(
+             index_data,
+             n_neighbors=max(50, n_neighbors),
+             random_state=random_state,
+             n_jobs=n_jobs,
+         )
+         indices = index.query(simulated, k=n_neighbors)[0]
+
+         uq, uq_counts = np.unique(labels[indices], return_counts=True)
+         uq_counts_norm = uq_counts / uq_counts.sum()
+         counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
+         for group, count_norm in zip(uq, uq_counts_norm, strict=False):
+             counts[group] = count_norm
+
+         return counts
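
A minimal usage sketch of the new comparison API with random data (the arrays, their shapes, and the values are purely illustrative; the import path follows the file location listed above):

    import numpy as np
    from pertpy.tools._perturbation_space._comparison import PerturbationComparison

    rng = np.random.default_rng(0)
    real = rng.normal(loc=1.0, size=(100, 50))                # real perturbed cells x features
    simulated = real + rng.normal(scale=0.1, size=(100, 50))  # simulated counterpart
    control = rng.normal(loc=0.0, size=(100, 50))             # control cells

    comparison = PerturbationComparison()
    accuracy = comparison.compare_classification(real, simulated, control)  # normalized, capped at 1.0
    fractions = comparison.compare_knn(real, simulated, control)            # e.g. {"comp": 0.9, "ctrl": 0.1}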
pertpy/tools/_perturbation_space/_discriminator_classifiers.py

@@ -1,7 +1,6 @@
  from __future__ import annotations

  import warnings
- from typing import TYPE_CHECKING, Literal

  import anndata
  import numpy as np
@@ -42,12 +41,12 @@ class LRClassifierSpace(PerturbationSpace):

     Args:
         adata: AnnData object of size cells x genes
-        target_col: .obs column that stores the perturbations. Defaults to "perturbations".
-        layer_key: Layer in adata to use. Defaults to None.
+        target_col: .obs column that stores the perturbations.
+        layer_key: Layer in adata to use.
         embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
-            Can only be specified if layer_key is None. Defaults to None.
-        test_split_size: Fraction of data to put in the test set. Default to 0.2.
-        max_iter: Maximum number of iterations taken for the solvers to converge. Defaults to 1000.
+            Can only be specified if layer_key is None.
+        test_split_size: Fraction of data to put in the test set.
+        max_iter: Maximum number of iterations taken for the solvers to converge.

     Returns:
         AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
@@ -163,24 +162,23 @@ class MLPClassifierSpace(PerturbationSpace):

     Args:
         adata: AnnData object of size cells x genes
-        target_col: .obs column that stores the perturbations. Defaults to "perturbations".
-        layer_key: Layer in adata to use. Defaults to None.
+        target_col: .obs column that stores the perturbations.
+        layer_key: Layer in adata to use.
         hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
             will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
-            Defaults to [512].
-        dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
-        batch_norm: Whether to apply batch normalization. Defaults to True.
-        batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
+        dropout: Amount of dropout applied, constant for all layers.
+        batch_norm: Whether to apply batch normalization.
+        batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
         test_split_size: Fraction of data to put in the test set. Default to 0.2.
         validation_split_size: Fraction of data to put in the validation set of the resultant train set.
            E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
-           will be used for validation. Defaults to 0.25.
-        max_epochs: Maximum number of epochs for training. Defaults to 20.
+           will be used for validation.
+        max_epochs: Maximum number of epochs for training.
        val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
            Note that this affects early stopping, as the model will be stopped if the validation performance does not
-           improve for patience epochs. Defaults to 2.
+           improve for patience epochs.
        patience: Number of validation performance checks without improvement, after which the early stopping flag
-           is activated and training is therefore stopped. Defaults to 2.
+           is activated and training is therefore stopped.

     Returns:
         AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
@@ -325,10 +323,10 @@ class MLP(torch.nn.Module):
         """
         Args:
             sizes: size of layers.
-            dropout: Dropout probability. Defaults to 0.0.
-            batch_norm: specifies if batch norm should be applied. Defaults to True.
-            layer_norm: specifies if layer norm should be applied, as commonly used in Transformers. Defaults to False.
-            last_layer_act: activation function of last layer. Defaults to "linear".
+            dropout: Dropout probability.
+            batch_norm: specifies if batch norm should be applied.
+            layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
+            last_layer_act: activation function of last layer.
         """
         super().__init__()
         layers = []
@@ -392,8 +390,8 @@ class PLDataset(Dataset):
         """
         Args:
             adata: AnnData object with observations and labels.
-            target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
-            label_col: key with the perturbation labels. Defaults to 'perturbations'.
+            target_col: key with the perturbation labels numerically encoded.
+            label_col: key with the perturbation labels.
             layer_key: key of the layer to be used as data, otherwise .X
         """
@@ -410,7 +408,7 @@ class PLDataset(Dataset):

     def __getitem__(self, idx):
         """Returns a sample and corresponding perturbations applied (labels)"""
-        sample = self.data[idx].A.squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
+        sample = self.data[idx].toarray().squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
         num_label = self.labels.iloc[idx]
         str_label = self.pert_labels.iloc[idx]

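
The `__getitem__` change replaces the `.A` shorthand with an explicit `.toarray()` call; `.A` is being phased out for SciPy sparse arrays while `.toarray()` remains stable. A self-contained sketch of the same densification pattern:

    import numpy as np
    import scipy.sparse

    X = scipy.sparse.random(10, 5, density=0.3, format="csr", random_state=0)
    idx = 3
    # densify a single row: explicit replacement for the old `X[idx].A`
    sample = X[idx].toarray().squeeze() if scipy.sparse.issparse(X) else X[idx]
    assert isinstance(sample, np.ndarray) and sample.shape == (5,)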
pertpy/tools/_perturbation_space/_perturbation_space.py

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
  import numpy as np
  import pandas as pd
  from anndata import AnnData
- from pynndescent import NNDescent
+ from lamin_utils import logger
  from rich import print

  if TYPE_CHECKING:
@@ -40,13 +40,13 @@ class PerturbationSpace:

     Args:
         adata: Anndata object of size cells x genes.
-        target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
-        group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups. Defaults to 'perturbations'.
-        reference_key: The key of the control values. Defaults to 'control'.
-        layer_key: Key of the AnnData layer to use for computation. Defaults to the `X` matrix otherwise.
-        new_layer_key: the results are stored in the given layer. Defaults to 'control_diff'.
-        embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
-        new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control_diff'.
+        target_col: .obs column name that stores the label of the perturbation applied to each cell.
+        group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
+        reference_key: The key of the control values.
+        layer_key: Key of the AnnData layer to use for computation.
+        new_layer_key: the results are stored in the given layer.
+        embedding_key: `obsm` key of the AnnData embedding to use for computation.
+        new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
         all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
         copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
@@ -150,14 +150,14 @@ class PerturbationSpace:
         ensure_consistency: bool = False,
         target_col: str = "perturbation",
     ) -> tuple[AnnData, AnnData] | AnnData:
-        """Add perturbations linearly. Assumes input of size n_perts x dimensionality
+        """Add perturbations linearly. Assumes input of size n_perts x dimensionality.

         Args:
             adata: Anndata object of size n_perts x dim.
             perturbations: Perturbations to add.
-            reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
+            reference_key: perturbation source from which the perturbation summation starts.
             ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-            target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
+            target_col: .obs column name that stores the label of the perturbation applied to each cell.

         Returns:
             Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
@@ -182,8 +182,8 @@ class PerturbationSpace:
             new_pert_name += perturbation + "+"

         if not ensure_consistency:
-            print(
-                "[bold yellow]Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
+            logger.warning(
+                "Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
                 "Subtract control perturbation needed for consistency of space in all data representations. \n"
                 "Run with ensure_consistency=True"
             )
@@ -264,9 +264,9 @@ class PerturbationSpace:
         Args:
             adata: Anndata object of size n_perts x dim.
             perturbations: Perturbations to subtract.
-            reference_key: Perturbation source from which the perturbation subtraction starts. Defaults to 'control'.
+            reference_key: Perturbation source from which the perturbation subtraction starts.
             ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-            target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
+            target_col: .obs column name that stores the label of the perturbation applied to each cell.

         Returns:
             Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
@@ -291,8 +291,8 @@ class PerturbationSpace:
             new_pert_name += perturbation + "-"

         if not ensure_consistency:
-            print(
-                "[bold yellow]Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
+            logger.warning(
+                "Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
                 "Subtract control perturbation needed for consistency of space in all data representations.\n"
                 "Run with ensure_consistency=True"
             )
@@ -372,10 +372,10 @@ class PerturbationSpace:

         Args:
             adata: The AnnData object containing single-cell data.
-            column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
-            target_val: The target value to impute. Defaults to "unknown".
-            n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
-            use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.
+            column: The column name in AnnData object to perform imputation on.
+            target_val: The target value to impute.
+            n_neighbors: Number of neighbors to use for imputation.
+            use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.

         Examples:
             >>> import pertpy as pt
@@ -396,6 +396,8 @@ class PerturbationSpace:

         embedding = adata.obsm[use_rep]

+        from pynndescent import NNDescent
+
         nnd = NNDescent(embedding, n_neighbors=n_neighbors)
         indices, _ = nnd.query(embedding, k=n_neighbors)

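
The last hunk defers the pynndescent import to the method body, so it is only paid for when the knn imputation actually runs. A minimal sketch of the query-then-majority-vote step the surrounding method performs (the data and labels are illustrative, not the method's real internals):

    import numpy as np

    rng = np.random.default_rng(0)
    embedding = rng.normal(size=(200, 10))  # stands in for adata.obsm["X_umap"]
    labels = np.array(["pert_a", "pert_b"] * 90 + ["unknown"] * 20, dtype=object)

    from pynndescent import NNDescent  # deferred, as in the hunk above

    nnd = NNDescent(embedding, n_neighbors=5)
    indices, _ = nnd.query(embedding, k=5)

    # impute each "unknown" cell from the majority label of its neighbors
    for i in np.flatnonzero(labels == "unknown"):
        votes = labels[indices[i]]
        votes = votes[votes != "unknown"]
        if votes.size:
            values, counts = np.unique(votes, return_counts=True)
            labels[i] = values[np.argmax(counts)]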
pertpy/tools/_perturbation_space/_simple.py

@@ -28,7 +28,7 @@ class CentroidSpace(PerturbationSpace):
        layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
        embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
        keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
-           each cell of one perturbation are kept. Defaults to True.
+           each cell of one perturbation are kept.

     Returns:
         AnnData object with one observation per perturbation, storing the embedding data of the
@@ -129,7 +129,7 @@ class PseudobulkSpace(PerturbationSpace):
        adata: Anndata object of size cells x genes
        target_col: .obs column that stores the label of the perturbation applied to each cell.
        groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
-           The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None.
+           The summarized expression per perturbation (target_col) and group (groups_col) is computed.
        layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
        embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
        **kwargs: Are passed to decoupler's get_pseudobulk.
@@ -254,7 +254,7 @@ class DBSCANSpace(ClusteringSpace):
        adata: Anndata object of size cells x genes
        layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
        embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
-       cluster_key: name of the .obs column to store the cluster labels. Defaults to 'dbscan'
+       cluster_key: name of the .obs column to store the cluster labels.
        copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
        return_object: if True returns the clustering object
        **kwargs: Are passed to sklearn's DBSCAN.
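
A hedged usage sketch of the pseudobulk workflow these docstrings describe (assuming `PseudobulkSpace` is exposed as `pt.tl.PseudobulkSpace`, with `kang_2018` standing in for any AnnData that carries a perturbation column; `mode` is forwarded to decoupler's get_pseudobulk):

    import pertpy as pt

    adata = pt.dt.kang_2018()        # .obs["label"] holds ctrl/stim
    ps = pt.tl.PseudobulkSpace()
    pb_adata = ps.compute(adata, target_col="label", mode="mean")
    print(pb_adata)                  # one observation per perturbation label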
pertpy/tools/_scgen/__init__.py

@@ -1 +1 @@
- from pertpy.tools._scgen._scgen import SCGEN
+ from pertpy.tools._scgen._scgen import Scgen
pertpy/tools/_scgen/_base_components.py

@@ -28,7 +28,7 @@ class FlaxEncoder(nn.Module):

        Args:
            x: The input data matrix.
-           training: Whether
+           training: Whether to use running training average.

        Returns:
            Mean and variance.
@@ -69,12 +69,11 @@ class FlaxDecoder(nn.Module):

        Args:
            x: Input data.
-           training:
+           training: Whether to use running training average.

        Returns:
            Decoded data.
        """
-
        training = nn.merge_param("training", self.training, training)

        for _ in range(self.n_layers):
pertpy/tools/_scgen/_scgen.py

@@ -10,6 +10,7 @@ import scanpy as sc
  from adjustText import adjust_text
  from anndata import AnnData
  from jax import Array
+ from lamin_utils import logger
  from scipy import stats
  from scvi import REGISTRY_KEYS
  from scvi.data import AnnDataManager
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
  font = {"family": "Arial", "size": 14}


- class SCGEN(JaxTrainingMixin, BaseModelClass):
+ class Scgen(JaxTrainingMixin, BaseModelClass):
      """Jax Implementation of scGen model for batch removal and perturbation prediction."""

      def __init__(
@@ -49,7 +50,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            **model_kwargs,
        )
        self._model_summary_string = (
-           f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
+           f"Scgen Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}"
        )
        self.init_params_ = self._get_init_params(locals())
@@ -79,8 +80,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        """
@@ -166,8 +167,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
@@ -200,8 +201,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
@@ -304,7 +305,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
@@ -345,8 +346,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
@@ -403,19 +404,19 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            gene_list: list of gene names to be plotted.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
-           verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
-           legend: if `True`: plots a legend, defaults to `True`.
+           verbose: Specify if you want information to be printed while creating the plot.
+           legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
-           x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
-           y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
-           fontsize: Fontsize used for text in the plot, defaults to 14.
+           x_coeff: Offset to print the R^2 value in x-direction.
+           y_coeff: Offset to print the R^2 value in y-direction.
+           fontsize: Fontsize used for text in the plot.
            **kwargs:

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> scg = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> scg = pt.tl.Scgen(data)
            >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
@@ -444,12 +445,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
            m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
            if verbose:
-               print("top_100 DEGs mean: ", r_value_diff**2)
+               logger.info("top_100 DEGs mean: ", r_value_diff**2)
            x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
            y = np.asarray(np.mean(stim.X, axis=0)).ravel()
            m, b, r_value, p_value, std_err = stats.linregress(x, y)
            if verbose:
-               print("All genes mean: ", r_value**2)
+               logger.info("All genes mean: ", r_value**2)
            df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
            ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
            ax.tick_params(labelsize=fontsize)
@@ -540,12 +541,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            gene_list: list of gene names to be plotted.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
-           legend: if `True`: plots a legend, defaults to `True`.
+           legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
-           verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
-           x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
-           y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
-           fontsize: Fontsize used for text in the plot, defaults to 14.
+           verbose: Specify if you want information to be printed while creating the plot.
+           x_coeff: Offset to print the R^2 value in x-direction.
+           y_coeff: Offset to print the R^2 value in y-direction.
+           fontsize: Fontsize used for text in the plot.
        """
        import seaborn as sns

@@ -566,14 +567,14 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
            m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
            if verbose:
-               print("Top 100 DEGs var: ", r_value_diff**2)
+               logger.info("Top 100 DEGs var: ", r_value_diff**2)
            if "y1" in axis_keys.keys():
                real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
            x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
            y = np.asarray(np.var(stim.X, axis=0)).ravel()
            m, b, r_value, p_value, std_err = stats.linregress(x, y)
            if verbose:
-               print("All genes var: ", r_value**2)
+               logger.info("All genes var: ", r_value**2)
            df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
            ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
            ax.tick_params(labelsize=fontsize)
@@ -637,7 +638,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):

      def plot_binary_classifier(
          self,
-         scgen: SCGEN,
+         scgen: Scgen,
          adata: AnnData | None,
          delta: np.ndarray,
          ctrl_key: str,
@@ -699,3 +700,7 @@
            if not (show or save):
                return ax
            return None
+
+
+ # compatibility
+ SCGEN = Scgen
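
The trailing alias keeps code written against the old class name importable from the module; a quick check of what the alias guarantees (both names come straight from the diff above):

    from pertpy.tools._scgen._scgen import SCGEN, Scgen

    # the compatibility alias points at the renamed class, not a copy
    assert SCGEN is Scgen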
pertpy/tools/_scgen/_utils.py

@@ -27,7 +27,7 @@ def extractor(
      Example:
          .. code-block:: python

-             import SCGEN
+             import Scgen
              import anndata

              train_data = anndata.read("./data/train.h5ad")
@@ -58,7 +58,7 @@ def balancer(
      Example:
          .. code-block:: python

-             import SCGEN
+             import Scgen
              import anndata

              train_data = anndata.read("./train_kang.h5ad")