pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. pertpy/__init__.py +2 -1
  2. pertpy/data/__init__.py +61 -0
  3. pertpy/data/_dataloader.py +27 -23
  4. pertpy/data/_datasets.py +58 -0
  5. pertpy/metadata/__init__.py +2 -0
  6. pertpy/metadata/_cell_line.py +39 -70
  7. pertpy/metadata/_compound.py +3 -4
  8. pertpy/metadata/_drug.py +2 -6
  9. pertpy/metadata/_look_up.py +38 -51
  10. pertpy/metadata/_metadata.py +7 -10
  11. pertpy/metadata/_moa.py +2 -6
  12. pertpy/plot/__init__.py +0 -5
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +2 -3
  15. pertpy/tools/__init__.py +42 -4
  16. pertpy/tools/_augur.py +14 -15
  17. pertpy/tools/_cinemaot.py +2 -2
  18. pertpy/tools/_coda/_base_coda.py +118 -142
  19. pertpy/tools/_coda/_sccoda.py +16 -15
  20. pertpy/tools/_coda/_tasccoda.py +21 -22
  21. pertpy/tools/_dialogue.py +18 -23
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +21 -16
  32. pertpy/tools/_distances/_distances.py +406 -70
  33. pertpy/tools/_enrichment.py +10 -15
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +76 -53
  36. pertpy/tools/_mixscape.py +15 -11
  37. pertpy/tools/_perturbation_space/_clustering.py +5 -2
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
  41. pertpy/tools/_perturbation_space/_simple.py +3 -3
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +33 -28
  45. pertpy/tools/_scgen/_utils.py +2 -2
  46. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
  47. pertpy-0.8.0.dist-info/RECORD +57 -0
  48. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  49. pertpy/plot/_augur.py +0 -171
  50. pertpy/plot/_coda.py +0 -601
  51. pertpy/plot/_guide_rna.py +0 -64
  52. pertpy/plot/_milopy.py +0 -209
  53. pertpy/plot/_mixscape.py +0 -355
  54. pertpy/tools/_differential_gene_expression.py +0 -325
  55. pertpy-0.7.0.dist-info/RECORD +0 -53
  56. {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,112 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import numpy as np
4
+ import pynndescent
5
+ from scipy.sparse import issparse
6
+ from scipy.sparse import vstack as sp_vstack
7
+ from sklearn.base import ClassifierMixin
8
+ from sklearn.linear_model import LogisticRegression
9
+
10
+ if TYPE_CHECKING:
11
+ from numpy.typing import NDArray
12
+
13
+
14
class PerturbationComparison:
    """Comparison between real and simulated perturbations."""

    @staticmethod
    def _vstack(blocks: tuple, sparse: bool):
        """Vertically stack ``blocks`` with ``scipy.sparse.vstack`` if ``sparse`` else ``numpy.vstack``."""
        return sp_vstack(blocks) if sparse else np.vstack(blocks)

    def compare_classification(
        self,
        real: np.ndarray,
        simulated: np.ndarray,
        control: np.ndarray,
        clf: ClassifierMixin | None = None,
    ) -> float:
        """Compare classification accuracy between real and simulated perturbations.

        Trains a classifier on the real perturbation data + the control data and reports a normalized
        classification accuracy on the simulated perturbation: the accuracy on ``simulated`` (all labelled
        as perturbed) divided by the accuracy on ``real`` (the perturbed training rows), clipped to 1.0.

        Args:
            real: Real perturbed data.
            simulated: Simulated perturbed data.
            control: Control data.
            clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.

        Returns:
            Normalized classification accuracy of the simulated perturbation (at most 1.0).

        Raises:
            ValueError: If the inputs do not have the same number of features, or if the fitted classifier
                classifies none of the real perturbed observations as perturbed (the normalization is undefined).
        """
        if not real.shape[1] == simulated.shape[1] == control.shape[1]:
            raise ValueError(
                "real, simulated and control must have the same number of features, got "
                f"{real.shape[1]}, {simulated.shape[1]} and {control.shape[1]}."
            )
        if clf is None:
            clf = LogisticRegression()

        n_real = real.shape[0]
        data = self._vstack((real, control), sparse=issparse(real))
        labels = np.concatenate([np.full(n_real, "comp"), np.full(control.shape[0], "ctrl")])

        clf.fit(data, labels)
        # The accuracy on the real perturbed (training) rows is the reference the simulated accuracy is scaled by.
        real_score = clf.score(real, labels[:n_real])
        if real_score == 0:
            raise ValueError(
                "The classifier did not classify any real perturbed observation as perturbed; "
                "the normalized classification score is undefined."
            )
        simulated_score = clf.score(simulated, np.full(simulated.shape[0], "comp"))

        return min(1.0, simulated_score / real_score)

    def compare_knn(
        self,
        real: np.ndarray,
        simulated: np.ndarray,
        control: np.ndarray | None = None,
        use_simulated_for_knn: bool = False,
        n_neighbors: int = 20,
        random_state: int = 0,
        n_jobs: int = 1,
    ) -> dict[str, float]:
        """Calculate proportions of real perturbed and control data points for simulated data.

        Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
        data points among the k nearest neighbors of the simulated data. If `control` is not provided,
        builds the knn graph from real perturbed + simulated perturbed.

        Args:
            real: Real perturbed data.
            simulated: Simulated perturbed data.
            control: Control data.
            use_simulated_for_knn: Include simulated perturbed data (`simulated`) into the knn graph. Only has
                an effect when `control` is provided; without control the simulated data is always included.
            n_neighbors: Number of neighbors to use in k-neighbor graph.
            random_state: Random state used for k-neighbor graph construction.
            n_jobs: Number of cores to use for the k-neighbor graph construction.

        Returns:
            Mapping from label group to the fraction of the simulated points' neighbors in that group:
            ``"comp"`` (real perturbed), ``"siml"`` (simulated, if in the graph) and ``"ctrl"`` (control, if
            provided). Groups without any neighbor map to 0.0.

        Raises:
            ValueError: If the inputs do not have the same number of features.
        """
        if real.shape[1] != simulated.shape[1]:
            raise ValueError(
                f"real and simulated must have the same number of features, got {real.shape[1]} and {simulated.shape[1]}."
            )
        if control is not None and real.shape[1] != control.shape[1]:
            raise ValueError(
                f"real and control must have the same number of features, got {real.shape[1]} and {control.shape[1]}."
            )

        n_sim = simulated.shape[0]
        sparse = issparse(real)

        if control is None:
            index_data = self._vstack((simulated, real), sparse=sparse)
        else:
            blocks = (simulated, real, control) if use_simulated_for_knn else (real, control)
            index_data = self._vstack(blocks, sparse=sparse)

        # Row layout of index_data: [simulated (optional)] + real + [control (optional)].
        sim_in_index = use_simulated_for_knn or control is None
        n_index = index_data.shape[0]
        label_groups = ["comp"]
        labels: NDArray[np.str_] = np.full(n_index, "comp")
        if sim_in_index:
            labels[:n_sim] = "siml"
            label_groups.append("siml")
        if control is not None:
            # Use an explicit start index: ``labels[-0:]`` would relabel every row when control is empty.
            labels[n_index - control.shape[0] :] = "ctrl"
            label_groups.append("ctrl")

        index = pynndescent.NNDescent(
            index_data,
            n_neighbors=max(50, n_neighbors),
            random_state=random_state,
            n_jobs=n_jobs,
        )
        indices = index.query(simulated, k=n_neighbors)[0]

        groups, group_counts = np.unique(labels[indices], return_counts=True)
        fractions = group_counts / group_counts.sum()
        counts = dict.fromkeys(label_groups, 0.0)
        for group, fraction in zip(groups, fractions, strict=True):
            counts[str(group)] = float(fraction)

        return counts
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import warnings
4
- from typing import TYPE_CHECKING, Literal
5
4
 
6
5
  import anndata
7
6
  import numpy as np
@@ -42,12 +41,12 @@ class LRClassifierSpace(PerturbationSpace):
42
41
 
43
42
  Args:
44
43
  adata: AnnData object of size cells x genes
45
- target_col: .obs column that stores the perturbations. Defaults to "perturbations".
46
- layer_key: Layer in adata to use. Defaults to None.
44
+ target_col: .obs column that stores the perturbations.
45
+ layer_key: Layer in adata to use.
47
46
  embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
48
- Can only be specified if layer_key is None. Defaults to None.
49
- test_split_size: Fraction of data to put in the test set. Default to 0.2.
50
- max_iter: Maximum number of iterations taken for the solvers to converge. Defaults to 1000.
47
+ Can only be specified if layer_key is None.
48
+ test_split_size: Fraction of data to put in the test set.
49
+ max_iter: Maximum number of iterations taken for the solvers to converge.
51
50
 
52
51
  Returns:
53
52
  AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
@@ -163,24 +162,23 @@ class MLPClassifierSpace(PerturbationSpace):
163
162
 
164
163
  Args:
165
164
  adata: AnnData object of size cells x genes
166
- target_col: .obs column that stores the perturbations. Defaults to "perturbations".
167
- layer_key: Layer in adata to use. Defaults to None.
165
+ target_col: .obs column that stores the perturbations.
166
+ layer_key: Layer in adata to use.
168
167
  hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
169
168
  will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
170
- Defaults to [512].
171
- dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
172
- batch_norm: Whether to apply batch normalization. Defaults to True.
173
- batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
169
+ dropout: Amount of dropout applied, constant for all layers.
170
+ batch_norm: Whether to apply batch normalization.
171
+ batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
174
172
  test_split_size: Fraction of data to put in the test set. Default to 0.2.
175
173
  validation_split_size: Fraction of data to put in the validation set of the resultant train set.
176
174
  E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
177
- will be used for validation. Defaults to 0.25.
178
- max_epochs: Maximum number of epochs for training. Defaults to 20.
175
+ will be used for validation.
176
+ max_epochs: Maximum number of epochs for training.
179
177
  val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
180
178
  Note that this affects early stopping, as the model will be stopped if the validation performance does not
181
- improve for patience epochs. Defaults to 2.
179
+ improve for patience epochs.
182
180
  patience: Number of validation performance checks without improvement, after which the early stopping flag
183
- is activated and training is therefore stopped. Defaults to 2.
181
+ is activated and training is therefore stopped.
184
182
 
185
183
  Returns:
186
184
  AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
@@ -325,10 +323,10 @@ class MLP(torch.nn.Module):
325
323
  """
326
324
  Args:
327
325
  sizes: size of layers.
328
- dropout: Dropout probability. Defaults to 0.0.
329
- batch_norm: specifies if batch norm should be applied. Defaults to True.
330
- layer_norm: specifies if layer norm should be applied, as commonly used in Transformers. Defaults to False.
331
- last_layer_act: activation function of last layer. Defaults to "linear".
326
+ dropout: Dropout probability.
327
+ batch_norm: specifies if batch norm should be applied.
328
+ layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
329
+ last_layer_act: activation function of last layer.
332
330
  """
333
331
  super().__init__()
334
332
  layers = []
@@ -392,8 +390,8 @@ class PLDataset(Dataset):
392
390
  """
393
391
  Args:
394
392
  adata: AnnData object with observations and labels.
395
- target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
396
- label_col: key with the perturbation labels. Defaults to 'perturbations'.
393
+ target_col: key with the perturbation labels numerically encoded.
394
+ label_col: key with the perturbation labels.
397
395
  layer_key: key of the layer to be used as data, otherwise .X
398
396
  """
399
397
 
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  from anndata import AnnData
8
- from pynndescent import NNDescent
8
+ from lamin_utils import logger
9
9
  from rich import print
10
10
 
11
11
  if TYPE_CHECKING:
@@ -40,13 +40,13 @@ class PerturbationSpace:
40
40
 
41
41
  Args:
42
42
  adata: Anndata object of size cells x genes.
43
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
44
- group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups. Defaults to 'perturbations'.
45
- reference_key: The key of the control values. Defaults to 'control'.
46
- layer_key: Key of the AnnData layer to use for computation. Defaults to the `X` matrix otherwise.
47
- new_layer_key: the results are stored in the given layer. Defaults to 'control_diff'.
48
- embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
49
- new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control_diff'.
43
+ target_col: .obs column name that stores the label of the perturbation applied to each cell.
44
+ group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
45
+ reference_key: The key of the control values.
46
+ layer_key: Key of the AnnData layer to use for computation.
47
+ new_layer_key: the results are stored in the given layer.
48
+ embedding_key: `obsm` key of the AnnData embedding to use for computation.
49
+ new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
50
50
  all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
51
51
  copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
52
52
 
@@ -150,14 +150,14 @@ class PerturbationSpace:
150
150
  ensure_consistency: bool = False,
151
151
  target_col: str = "perturbation",
152
152
  ) -> tuple[AnnData, AnnData] | AnnData:
153
- """Add perturbations linearly. Assumes input of size n_perts x dimensionality
153
+ """Add perturbations linearly. Assumes input of size n_perts x dimensionality.
154
154
 
155
155
  Args:
156
156
  adata: Anndata object of size n_perts x dim.
157
157
  perturbations: Perturbations to add.
158
- reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
158
+ reference_key: perturbation source from which the perturbation summation starts.
159
159
  ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
160
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
160
+ target_col: .obs column name that stores the label of the perturbation applied to each cell.
161
161
 
162
162
  Returns:
163
163
  Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
@@ -182,8 +182,8 @@ class PerturbationSpace:
182
182
  new_pert_name += perturbation + "+"
183
183
 
184
184
  if not ensure_consistency:
185
- print(
186
- "[bold yellow]Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
185
+ logger.warning(
186
+ "Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
187
187
  "Subtract control perturbation needed for consistency of space in all data representations. \n"
188
188
  "Run with ensure_consistency=True"
189
189
  )
@@ -264,9 +264,9 @@ class PerturbationSpace:
264
264
  Args:
265
265
  adata: Anndata object of size n_perts x dim.
266
266
  perturbations: Perturbations to subtract.
267
- reference_key: Perturbation source from which the perturbation subtraction starts. Defaults to 'control'.
267
+ reference_key: Perturbation source from which the perturbation subtraction starts.
268
268
  ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
269
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
269
+ target_col: .obs column name that stores the label of the perturbation applied to each cell.
270
270
 
271
271
  Returns:
272
272
  Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
@@ -291,8 +291,8 @@ class PerturbationSpace:
291
291
  new_pert_name += perturbation + "-"
292
292
 
293
293
  if not ensure_consistency:
294
- print(
295
- "[bold yellow]Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
294
+ logger.warning(
295
+ "Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
296
296
  "Subtract control perturbation needed for consistency of space in all data representations.\n"
297
297
  "Run with ensure_consistency=True"
298
298
  )
@@ -372,10 +372,10 @@ class PerturbationSpace:
372
372
 
373
373
  Args:
374
374
  adata: The AnnData object containing single-cell data.
375
- column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
376
- target_val: The target value to impute. Defaults to "unknown".
377
- n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
378
- use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.
375
+ column: The column name in AnnData object to perform imputation on.
376
+ target_val: The target value to impute.
377
+ n_neighbors: Number of neighbors to use for imputation.
378
+ use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
379
379
 
380
380
  Examples:
381
381
  >>> import pertpy as pt
@@ -396,6 +396,8 @@ class PerturbationSpace:
396
396
 
397
397
  embedding = adata.obsm[use_rep]
398
398
 
399
+ from pynndescent import NNDescent
400
+
399
401
  nnd = NNDescent(embedding, n_neighbors=n_neighbors)
400
402
  indices, _ = nnd.query(embedding, k=n_neighbors)
401
403
 
@@ -28,7 +28,7 @@ class CentroidSpace(PerturbationSpace):
28
28
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
29
29
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
30
30
  keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
31
- each cell of one perturbation are kept. Defaults to True.
31
+ each cell of one perturbation are kept.
32
32
 
33
33
  Returns:
34
34
  AnnData object with one observation per perturbation, storing the embedding data of the
@@ -129,7 +129,7 @@ class PseudobulkSpace(PerturbationSpace):
129
129
  adata: Anndata object of size cells x genes
130
130
  target_col: .obs column that stores the label of the perturbation applied to each cell.
131
131
  groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
132
- The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None.
132
+ The summarized expression per perturbation (target_col) and group (groups_col) is computed.
133
133
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
134
134
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
135
135
  **kwargs: Are passed to decoupler's get_pseuobulk.
@@ -254,7 +254,7 @@ class DBSCANSpace(ClusteringSpace):
254
254
  adata: Anndata object of size cells x genes
255
255
  layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
256
256
  embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
257
- cluster_key: name of the .obs column to store the cluster labels. Defaults to 'dbscan'
257
+ cluster_key: name of the .obs column to store the cluster labels.
258
258
  copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
259
259
  return_object: if True returns the clustering object
260
260
  **kwargs: Are passed to sklearn's DBSCAN.
@@ -1 +1 @@
1
- from pertpy.tools._scgen._scgen import SCGEN
1
+ from pertpy.tools._scgen._scgen import Scgen
@@ -28,7 +28,7 @@ class FlaxEncoder(nn.Module):
28
28
 
29
29
  Args:
30
30
  x: The input data matrix.
31
- training: Whether
31
+ training: Whether to use running training average.
32
32
 
33
33
  Returns:
34
34
  Mean and variance.
@@ -69,12 +69,11 @@ class FlaxDecoder(nn.Module):
69
69
 
70
70
  Args:
71
71
  x: Input data.
72
- training:
72
+ training: Whether to use running training average.
73
73
 
74
74
  Returns:
75
75
  Decoded data.
76
76
  """
77
-
78
77
  training = nn.merge_param("training", self.training, training)
79
78
 
80
79
  for _ in range(self.n_layers):
@@ -10,6 +10,7 @@ import scanpy as sc
10
10
  from adjustText import adjust_text
11
11
  from anndata import AnnData
12
12
  from jax import Array
13
+ from lamin_utils import logger
13
14
  from scipy import stats
14
15
  from scvi import REGISTRY_KEYS
15
16
  from scvi.data import AnnDataManager
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
26
27
  font = {"family": "Arial", "size": 14}
27
28
 
28
29
 
29
- class SCGEN(JaxTrainingMixin, BaseModelClass):
30
+ class Scgen(JaxTrainingMixin, BaseModelClass):
30
31
  """Jax Implementation of scGen model for batch removal and perturbation prediction."""
31
32
 
32
33
  def __init__(
@@ -49,7 +50,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
49
50
  **model_kwargs,
50
51
  )
51
52
  self._model_summary_string = (
52
- f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
53
+ f"Scgen Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
53
54
  f"{dropout_rate}"
54
55
  )
55
56
  self.init_params_ = self._get_init_params(locals())
@@ -79,8 +80,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
79
80
  Examples:
80
81
  >>> import pertpy as pt
81
82
  >>> data = pt.dt.kang_2018()
82
- >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
83
- >>> model = pt.tl.SCGEN(data)
83
+ >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
84
+ >>> model = pt.tl.Scgen(data)
84
85
  >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
85
86
  >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
86
87
  """
@@ -166,8 +167,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
166
167
  Examples:
167
168
  >>> import pertpy as pt
168
169
  >>> data = pt.dt.kang_2018()
169
- >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
170
- >>> model = pt.tl.SCGEN(data)
170
+ >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
171
+ >>> model = pt.tl.Scgen(data)
171
172
  >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
172
173
  >>> decoded_X = model.get_decoded_expression()
173
174
  """
@@ -200,8 +201,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
200
201
  Examples:
201
202
  >>> import pertpy as pt
202
203
  >>> data = pt.dt.kang_2018()
203
- >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
204
- >>> model = pt.tl.SCGEN(data)
204
+ >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
205
+ >>> model = pt.tl.Scgen(data)
205
206
  >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
206
207
  >>> corrected_adata = model.batch_removal()
207
208
  """
@@ -304,7 +305,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
304
305
  Examples:
305
306
  >>> import pertpy as pt
306
307
  >>> data = pt.dt.kang_2018()
307
- >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
308
+ >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
308
309
  """
309
310
  setup_method_args = cls._get_setup_method_args(**locals())
310
311
  anndata_fields = [
@@ -345,8 +346,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
345
346
  Examples:
346
347
  >>> import pertpy as pt
347
348
  >>> data = pt.dt.kang_2018()
348
- >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
349
- >>> model = pt.tl.SCGEN(data)
349
+ >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
350
+ >>> model = pt.tl.Scgen(data)
350
351
  >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
351
352
  >>> latent_X = model.get_latent_representation()
352
353
  """
@@ -403,19 +404,19 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
403
404
  gene_list: list of gene names to be plotted.
404
405
  show: if `True`: will show to the plot after saving it.
405
406
  top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
406
- verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
407
- legend: if `True`: plots a legend, defaults to `True`.
407
+ verbose: Specify if you want information to be printed while creating the plot.
408
+ legend: Whether to plot a legend.
408
409
  title: Set if you want the plot to display a title.
409
- x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
410
- y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
411
- fontsize: Fontsize used for text in the plot, defaults to 14.
410
+ x_coeff: Offset to print the R^2 value in x-direction.
411
+ y_coeff: Offset to print the R^2 value in y-direction.
412
+ fontsize: Fontsize used for text in the plot.
412
413
  **kwargs:
413
414
 
414
415
  Examples:
415
416
  >>> import pertpy as pt
416
417
  >>> data = pt.dt.kang_2018()
417
- >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
418
- >>> scg = pt.tl.SCGEN(data)
418
+ >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
419
+ >>> scg = pt.tl.Scgen(data)
419
420
  >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
420
421
  >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
421
422
  >>> pred.obs['label'] = 'pred'
@@ -444,12 +445,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
444
445
  y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
445
446
  m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
446
447
  if verbose:
447
- print("top_100 DEGs mean: ", r_value_diff**2)
448
+ logger.info("top_100 DEGs mean: ", r_value_diff**2)
448
449
  x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
449
450
  y = np.asarray(np.mean(stim.X, axis=0)).ravel()
450
451
  m, b, r_value, p_value, std_err = stats.linregress(x, y)
451
452
  if verbose:
452
- print("All genes mean: ", r_value**2)
453
+ logger.info("All genes mean: ", r_value**2)
453
454
  df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
454
455
  ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
455
456
  ax.tick_params(labelsize=fontsize)
@@ -540,12 +541,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
540
541
  gene_list: list of gene names to be plotted.
541
542
  show: if `True`: will show to the plot after saving it.
542
543
  top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
543
- legend: if `True`: plots a legend, defaults to `True`.
544
+ legend: Whether to plot a legend.
544
545
  title: Set if you want the plot to display a title.
545
- verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
546
- x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
547
- y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
548
- fontsize: Fontsize used for text in the plot, defaults to 14.
546
+ verbose: Specify if you want information to be printed while creating the plot.
547
+ x_coeff: Offset to print the R^2 value in x-direction.
548
+ y_coeff: Offset to print the R^2 value in y-direction.
549
+ fontsize: Fontsize used for text in the plot.
549
550
  """
550
551
  import seaborn as sns
551
552
 
@@ -566,14 +567,14 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
566
567
  y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
567
568
  m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
568
569
  if verbose:
569
- print("Top 100 DEGs var: ", r_value_diff**2)
570
+ logger.info("Top 100 DEGs var: ", r_value_diff**2)
570
571
  if "y1" in axis_keys.keys():
571
572
  real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
572
573
  x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
573
574
  y = np.asarray(np.var(stim.X, axis=0)).ravel()
574
575
  m, b, r_value, p_value, std_err = stats.linregress(x, y)
575
576
  if verbose:
576
- print("All genes var: ", r_value**2)
577
+ logger.info("All genes var: ", r_value**2)
577
578
  df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
578
579
  ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
579
580
  ax.tick_params(labelsize=fontsize)
@@ -637,7 +638,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
637
638
 
638
639
  def plot_binary_classifier(
639
640
  self,
640
- scgen: SCGEN,
641
+ scgen: Scgen,
641
642
  adata: AnnData | None,
642
643
  delta: np.ndarray,
643
644
  ctrl_key: str,
@@ -699,3 +700,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
699
700
  if not (show or save):
700
701
  return ax
701
702
  return None
703
+
704
+
705
+ # compatibility
706
+ SCGEN = Scgen
@@ -27,7 +27,7 @@ def extractor(
27
27
  Example:
28
28
  .. code-block:: python
29
29
 
30
- import SCGEN
30
+ import Scgen
31
31
  import anndata
32
32
 
33
33
  train_data = anndata.read("./data/train.h5ad")
@@ -58,7 +58,7 @@ def balancer(
58
58
  Example:
59
59
  .. code-block:: python
60
60
 
61
- import SCGEN
61
+ import Scgen
62
62
  import anndata
63
63
 
64
64
  train_data = anndata.read("./train_kang.h5ad")
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pertpy
3
- Version: 0.7.0
3
+ Version: 0.8.0
4
4
  Summary: Perturbation Analysis in the scverse ecosystem.
5
5
  Project-URL: Documentation, https://pertpy.readthedocs.io
6
- Project-URL: Source, https://github.com/theislab/pertpy
7
- Project-URL: Home-page, https://github.com/theislab/pertpy
8
- Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Amir Moinfar, Sergei Rybakov, Tessa Green, Stefan Peidli, Antonia Schumacher, Lilly May
6
+ Project-URL: Source, https://github.com/scverse/pertpy
7
+ Project-URL: Home-page, https://github.com/scverse/pertpy
8
+ Author: Lukas Heumos, Yuge Ji, Lilly May, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Tessa Green, Stefan Peidli, Antonia Schumacher, Gregor Sturm
9
9
  Maintainer-email: Lukas Heumos <lukas.heumos@posteo.net>
10
10
  License: MIT License
11
11
 
@@ -45,11 +45,10 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
45
45
  Classifier: Topic :: Scientific/Engineering :: Visualization
46
46
  Requires-Python: >=3.10
47
47
  Requires-Dist: adjusttext
48
- Requires-Dist: arviz
49
48
  Requires-Dist: blitzgsea
50
49
  Requires-Dist: decoupler
50
+ Requires-Dist: lamin-utils
51
51
  Requires-Dist: muon
52
- Requires-Dist: numpyro
53
52
  Requires-Dist: openpyxl
54
53
  Requires-Dist: ott-jax
55
54
  Requires-Dist: pubchempy
@@ -61,10 +60,14 @@ Requires-Dist: scikit-misc
61
60
  Requires-Dist: scipy
62
61
  Requires-Dist: scvi-tools
63
62
  Requires-Dist: sparsecca
64
- Requires-Dist: toytree
65
63
  Provides-Extra: coda
64
+ Requires-Dist: arviz; extra == 'coda'
66
65
  Requires-Dist: ete3; extra == 'coda'
67
66
  Requires-Dist: pyqt5; extra == 'coda'
67
+ Requires-Dist: toytree; extra == 'coda'
68
+ Provides-Extra: de
69
+ Requires-Dist: formulaic; extra == 'de'
70
+ Requires-Dist: pydeseq2; extra == 'de'
68
71
  Provides-Extra: dev
69
72
  Requires-Dist: pre-commit; extra == 'dev'
70
73
  Provides-Extra: doc
@@ -94,18 +97,18 @@ Requires-Dist: pytest; extra == 'test'
94
97
  Description-Content-Type: text/markdown
95
98
 
96
99
  [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
97
- [![Build](https://github.com/theislab/pertpy/actions/workflows/build.yml/badge.svg)](https://github.com/theislab/pertpy/actions/workflows/build.yml)
98
- [![codecov](https://codecov.io/gh/theislab/pertpy/graph/badge.svg?token=1dTpIPBShv)](https://codecov.io/gh/theislab/pertpy)
99
- [![License](https://img.shields.io/github/license/theislab/pertpy)](https://opensource.org/licenses/Apache2.0)
100
+ [![Build](https://github.com/scverse/pertpy/actions/workflows/build.yml/badge.svg)](https://github.com/scverse/pertpy/actions/workflows/build.yml)
101
+ [![codecov](https://codecov.io/gh/scverse/pertpy/graph/badge.svg?token=1dTpIPBShv)](https://codecov.io/gh/scverse/pertpy)
102
+ [![License](https://img.shields.io/github/license/scverse/pertpy)](https://opensource.org/licenses/Apache2.0)
100
103
  [![PyPI](https://img.shields.io/pypi/v/pertpy.svg)](https://pypi.org/project/pertpy/)
101
104
  [![Python Version](https://img.shields.io/pypi/pyversions/pertpy)](https://pypi.org/project/pertpy)
102
105
  [![Read the Docs](https://img.shields.io/readthedocs/pertpy/latest.svg?label=Read%20the%20Docs)](https://pertpy.readthedocs.io/)
103
- [![Test](https://github.com/theislab/pertpy/actions/workflows/test.yml/badge.svg)](https://github.com/theislab/pertpy/actions/workflows/test.yml)
106
+ [![Test](https://github.com/scverse/pertpy/actions/workflows/test.yml/badge.svg)](https://github.com/scverse/pertpy/actions/workflows/test.yml)
104
107
  [![PyPI](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
105
108
 
106
109
  # pertpy
107
110
 
108
- ![fig1](https://github.com/theislab/pertpy/assets/99650244/182fa9c3-6d23-4002-b86a-82bf2a243377)
111
+ ![fig1](https://github.com/scverse/pertpy/assets/99650244/182fa9c3-6d23-4002-b86a-82bf2a243377)
109
112
 
110
113
  ## Documentation
111
114
 
@@ -119,12 +122,18 @@ You can install _pertpy_ via [pip] from [PyPI]:
119
122
  pip install pertpy
120
123
  ```
121
124
 
122
- if you want to use scCODA please install it as:
125
+ If you want to use scCODA or tascCODA, please install pertpy as follows:
123
126
 
124
127
  ```console
125
128
  pip install pertpy[coda]
126
129
  ```
127
130
 
131
+ If you want to use the differential gene expression interface, please install pertpy by running:
132
+
133
+ ```console
134
+ pip install pertpy[de]
135
+ ```
136
+
128
137
  [pip]: https://pip.pypa.io/
129
138
  [pypi]: https://pypi.org/
130
139
  [usage]: https://pertpy.readthedocs.io/en/latest/usage/usage.html