pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +2 -1
- pertpy/data/__init__.py +61 -0
- pertpy/data/_dataloader.py +27 -23
- pertpy/data/_datasets.py +58 -0
- pertpy/metadata/__init__.py +2 -0
- pertpy/metadata/_cell_line.py +39 -70
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_drug.py +2 -6
- pertpy/metadata/_look_up.py +38 -51
- pertpy/metadata/_metadata.py +7 -10
- pertpy/metadata/_moa.py +2 -6
- pertpy/plot/__init__.py +0 -5
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +2 -3
- pertpy/tools/__init__.py +42 -4
- pertpy/tools/_augur.py +14 -15
- pertpy/tools/_cinemaot.py +2 -2
- pertpy/tools/_coda/_base_coda.py +118 -142
- pertpy/tools/_coda/_sccoda.py +16 -15
- pertpy/tools/_coda/_tasccoda.py +21 -22
- pertpy/tools/_dialogue.py +18 -23
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +21 -16
- pertpy/tools/_distances/_distances.py +406 -70
- pertpy/tools/_enrichment.py +10 -15
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +76 -53
- pertpy/tools/_mixscape.py +15 -11
- pertpy/tools/_perturbation_space/_clustering.py +5 -2
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
- pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
- pertpy/tools/_perturbation_space/_simple.py +3 -3
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +33 -28
- pertpy/tools/_scgen/_utils.py +2 -2
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -171
- pertpy/plot/_coda.py +0 -601
- pertpy/plot/_guide_rna.py +0 -64
- pertpy/plot/_milopy.py +0 -209
- pertpy/plot/_mixscape.py +0 -355
- pertpy/tools/_differential_gene_expression.py +0 -325
- pertpy-0.7.0.dist-info/RECORD +0 -53
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,112 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import pynndescent
|
5
|
+
from scipy.sparse import issparse
|
6
|
+
from scipy.sparse import vstack as sp_vstack
|
7
|
+
from sklearn.base import ClassifierMixin
|
8
|
+
from sklearn.linear_model import LogisticRegression
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from numpy.typing import NDArray
|
12
|
+
|
13
|
+
|
14
|
+
class PerturbationComparison:
|
15
|
+
"""Comparison between real and simulated perturbations."""
|
16
|
+
|
17
|
+
def compare_classification(
|
18
|
+
self,
|
19
|
+
real: np.ndarray,
|
20
|
+
simulated: np.ndarray,
|
21
|
+
control: np.ndarray,
|
22
|
+
clf: ClassifierMixin | None = None,
|
23
|
+
) -> float:
|
24
|
+
"""Compare classification accuracy between real and simulated perturbations.
|
25
|
+
|
26
|
+
Trains a classifier on the real perturbation data + the control data and reports a normalized
|
27
|
+
classification accuracy on the simulated perturbation.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
real: Real perturbed data.
|
31
|
+
simulated: Simulated perturbed data.
|
32
|
+
control: Control data
|
33
|
+
clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
|
34
|
+
"""
|
35
|
+
assert real.shape[1] == simulated.shape[1] == control.shape[1]
|
36
|
+
if clf is None:
|
37
|
+
clf = LogisticRegression()
|
38
|
+
n_x = real.shape[0]
|
39
|
+
data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control))
|
40
|
+
labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")])
|
41
|
+
|
42
|
+
clf.fit(data, labels)
|
43
|
+
norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x])
|
44
|
+
norm_score = min(1.0, norm_score)
|
45
|
+
|
46
|
+
return norm_score
|
47
|
+
|
48
|
+
def compare_knn(
|
49
|
+
self,
|
50
|
+
real: np.ndarray,
|
51
|
+
simulated: np.ndarray,
|
52
|
+
control: np.ndarray | None = None,
|
53
|
+
use_simulated_for_knn: bool = False,
|
54
|
+
n_neighbors: int = 20,
|
55
|
+
random_state: int = 0,
|
56
|
+
n_jobs: int = 1,
|
57
|
+
) -> dict[str, float]:
|
58
|
+
"""Calculate proportions of real perturbed and control data points for simulated data.
|
59
|
+
|
60
|
+
Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
|
61
|
+
data points for simulated data. If control (`control`) is not provided, builds the knn graph from
|
62
|
+
real perturbed + simulated perturbed.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
real: Real perturbed data.
|
66
|
+
simulated: Simulated perturbed data.
|
67
|
+
control: Control data
|
68
|
+
use_simulated_for_knn: Include simulated perturbed data (`simulated`) in the knn graph. Only valid when
|
69
|
+
control (`control`) is provided.
|
70
|
+
n_neighbors: Number of neighbors to use in k-neighbor graph.
|
71
|
+
random_state: Random state used for k-neighbor graph construction.
|
72
|
+
n_jobs: Number of cores to use. Defaults to 1; pass -1 to use all cores.
|
73
|
+
|
74
|
+
"""
|
75
|
+
assert real.shape[1] == simulated.shape[1]
|
76
|
+
if control is not None:
|
77
|
+
assert real.shape[1] == control.shape[1]
|
78
|
+
|
79
|
+
n_y = simulated.shape[0]
|
80
|
+
|
81
|
+
if control is None:
|
82
|
+
index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real))
|
83
|
+
else:
|
84
|
+
datas = (simulated, real, control) if use_simulated_for_knn else (real, control)
|
85
|
+
index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas)
|
86
|
+
|
87
|
+
y_in_index = use_simulated_for_knn or control is None
|
88
|
+
c_in_index = control is not None
|
89
|
+
label_groups = ["comp"]
|
90
|
+
labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp")
|
91
|
+
if y_in_index:
|
92
|
+
labels[:n_y] = "siml"
|
93
|
+
label_groups.append("siml")
|
94
|
+
if c_in_index:
|
95
|
+
labels[-control.shape[0] :] = "ctrl"
|
96
|
+
label_groups.append("ctrl")
|
97
|
+
|
98
|
+
index = pynndescent.NNDescent(
|
99
|
+
index_data,
|
100
|
+
n_neighbors=max(50, n_neighbors),
|
101
|
+
random_state=random_state,
|
102
|
+
n_jobs=n_jobs,
|
103
|
+
)
|
104
|
+
indices = index.query(simulated, k=n_neighbors)[0]
|
105
|
+
|
106
|
+
uq, uq_counts = np.unique(labels[indices], return_counts=True)
|
107
|
+
uq_counts_norm = uq_counts / uq_counts.sum()
|
108
|
+
counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
|
109
|
+
for group, count_norm in zip(uq, uq_counts_norm, strict=False):
|
110
|
+
counts[group] = count_norm
|
111
|
+
|
112
|
+
return counts
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import warnings
|
4
|
-
from typing import TYPE_CHECKING, Literal
|
5
4
|
|
6
5
|
import anndata
|
7
6
|
import numpy as np
|
@@ -42,12 +41,12 @@ class LRClassifierSpace(PerturbationSpace):
|
|
42
41
|
|
43
42
|
Args:
|
44
43
|
adata: AnnData object of size cells x genes
|
45
|
-
target_col: .obs column that stores the perturbations.
|
46
|
-
layer_key: Layer in adata to use.
|
44
|
+
target_col: .obs column that stores the perturbations.
|
45
|
+
layer_key: Layer in adata to use.
|
47
46
|
embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
|
48
|
-
Can only be specified if layer_key is None.
|
49
|
-
test_split_size: Fraction of data to put in the test set.
|
50
|
-
max_iter: Maximum number of iterations taken for the solvers to converge.
|
47
|
+
Can only be specified if layer_key is None.
|
48
|
+
test_split_size: Fraction of data to put in the test set.
|
49
|
+
max_iter: Maximum number of iterations taken for the solvers to converge.
|
51
50
|
|
52
51
|
Returns:
|
53
52
|
AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
|
@@ -163,24 +162,23 @@ class MLPClassifierSpace(PerturbationSpace):
|
|
163
162
|
|
164
163
|
Args:
|
165
164
|
adata: AnnData object of size cells x genes
|
166
|
-
target_col: .obs column that stores the perturbations.
|
167
|
-
layer_key: Layer in adata to use.
|
165
|
+
target_col: .obs column that stores the perturbations.
|
166
|
+
layer_key: Layer in adata to use.
|
168
167
|
hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
|
169
168
|
will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
|
169
|
+
dropout: Amount of dropout applied, constant for all layers.
|
170
|
+
batch_norm: Whether to apply batch normalization.
|
171
|
+
batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
|
174
172
|
test_split_size: Fraction of data to put in the test set. Defaults to 0.2.
|
175
173
|
validation_split_size: Fraction of data to put in the validation set of the resultant train set.
|
176
174
|
E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
|
177
|
-
will be used for validation.
|
178
|
-
max_epochs: Maximum number of epochs for training.
|
175
|
+
will be used for validation.
|
176
|
+
max_epochs: Maximum number of epochs for training.
|
179
177
|
val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
|
180
178
|
Note that this affects early stopping, as the model will be stopped if the validation performance does not
|
181
|
-
improve for patience epochs.
|
179
|
+
improve for patience epochs.
|
182
180
|
patience: Number of validation performance checks without improvement, after which the early stopping flag
|
183
|
-
is activated and training is therefore stopped.
|
181
|
+
is activated and training is therefore stopped.
|
184
182
|
|
185
183
|
Returns:
|
186
184
|
AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
|
@@ -325,10 +323,10 @@ class MLP(torch.nn.Module):
|
|
325
323
|
"""
|
326
324
|
Args:
|
327
325
|
sizes: size of layers.
|
328
|
-
dropout: Dropout probability.
|
329
|
-
batch_norm: specifies if batch norm should be applied.
|
330
|
-
layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
|
331
|
-
last_layer_act: activation function of last layer.
|
326
|
+
dropout: Dropout probability.
|
327
|
+
batch_norm: specifies if batch norm should be applied.
|
328
|
+
layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
|
329
|
+
last_layer_act: activation function of last layer.
|
332
330
|
"""
|
333
331
|
super().__init__()
|
334
332
|
layers = []
|
@@ -392,8 +390,8 @@ class PLDataset(Dataset):
|
|
392
390
|
"""
|
393
391
|
Args:
|
394
392
|
adata: AnnData object with observations and labels.
|
395
|
-
target_col: key with the perturbation labels numerically encoded.
|
396
|
-
label_col: key with the perturbation labels.
|
393
|
+
target_col: key with the perturbation labels numerically encoded.
|
394
|
+
label_col: key with the perturbation labels.
|
397
395
|
layer_key: key of the layer to be used as data, otherwise .X
|
398
396
|
"""
|
399
397
|
|
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
|
5
5
|
import numpy as np
|
6
6
|
import pandas as pd
|
7
7
|
from anndata import AnnData
|
8
|
-
from
|
8
|
+
from lamin_utils import logger
|
9
9
|
from rich import print
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
@@ -40,13 +40,13 @@ class PerturbationSpace:
|
|
40
40
|
|
41
41
|
Args:
|
42
42
|
adata: Anndata object of size cells x genes.
|
43
|
-
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
44
|
-
group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups.
|
45
|
-
reference_key: The key of the control values.
|
46
|
-
layer_key: Key of the AnnData layer to use for computation.
|
47
|
-
new_layer_key: the results are stored in the given layer.
|
48
|
-
embedding_key: `obsm` key of the AnnData embedding to use for computation.
|
49
|
-
new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
|
43
|
+
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
44
|
+
group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
|
45
|
+
reference_key: The key of the control values.
|
46
|
+
layer_key: Key of the AnnData layer to use for computation.
|
47
|
+
new_layer_key: the results are stored in the given layer.
|
48
|
+
embedding_key: `obsm` key of the AnnData embedding to use for computation.
|
49
|
+
new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
|
50
50
|
all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
|
51
51
|
copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
|
52
52
|
|
@@ -150,14 +150,14 @@ class PerturbationSpace:
|
|
150
150
|
ensure_consistency: bool = False,
|
151
151
|
target_col: str = "perturbation",
|
152
152
|
) -> tuple[AnnData, AnnData] | AnnData:
|
153
|
-
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality
|
153
|
+
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality.
|
154
154
|
|
155
155
|
Args:
|
156
156
|
adata: Anndata object of size n_perts x dim.
|
157
157
|
perturbations: Perturbations to add.
|
158
|
-
reference_key: perturbation source from which the perturbation summation starts.
|
158
|
+
reference_key: perturbation source from which the perturbation summation starts.
|
159
159
|
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
|
160
|
-
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
160
|
+
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
161
161
|
|
162
162
|
Returns:
|
163
163
|
Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
|
@@ -182,8 +182,8 @@ class PerturbationSpace:
|
|
182
182
|
new_pert_name += perturbation + "+"
|
183
183
|
|
184
184
|
if not ensure_consistency:
|
185
|
-
|
186
|
-
"
|
185
|
+
logger.warning(
|
186
|
+
"Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
|
187
187
|
"Subtract control perturbation needed for consistency of space in all data representations. \n"
|
188
188
|
"Run with ensure_consistency=True"
|
189
189
|
)
|
@@ -264,9 +264,9 @@ class PerturbationSpace:
|
|
264
264
|
Args:
|
265
265
|
adata: Anndata object of size n_perts x dim.
|
266
266
|
perturbations: Perturbations to subtract.
|
267
|
-
reference_key: Perturbation source from which the perturbation subtraction starts.
|
267
|
+
reference_key: Perturbation source from which the perturbation subtraction starts.
|
268
268
|
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
|
269
|
-
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
269
|
+
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
270
270
|
|
271
271
|
Returns:
|
272
272
|
Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
|
@@ -291,8 +291,8 @@ class PerturbationSpace:
|
|
291
291
|
new_pert_name += perturbation + "-"
|
292
292
|
|
293
293
|
if not ensure_consistency:
|
294
|
-
|
295
|
-
"
|
294
|
+
logger.warning(
|
295
|
+
"Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
|
296
296
|
"Subtract control perturbation needed for consistency of space in all data representations.\n"
|
297
297
|
"Run with ensure_consistency=True"
|
298
298
|
)
|
@@ -372,10 +372,10 @@ class PerturbationSpace:
|
|
372
372
|
|
373
373
|
Args:
|
374
374
|
adata: The AnnData object containing single-cell data.
|
375
|
-
column: The column name in AnnData object to perform imputation on.
|
376
|
-
target_val: The target value to impute.
|
377
|
-
n_neighbors: Number of neighbors to use for imputation.
|
378
|
-
use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
|
375
|
+
column: The column name in AnnData object to perform imputation on.
|
376
|
+
target_val: The target value to impute.
|
377
|
+
n_neighbors: Number of neighbors to use for imputation.
|
378
|
+
use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
|
379
379
|
|
380
380
|
Examples:
|
381
381
|
>>> import pertpy as pt
|
@@ -396,6 +396,8 @@ class PerturbationSpace:
|
|
396
396
|
|
397
397
|
embedding = adata.obsm[use_rep]
|
398
398
|
|
399
|
+
from pynndescent import NNDescent
|
400
|
+
|
399
401
|
nnd = NNDescent(embedding, n_neighbors=n_neighbors)
|
400
402
|
indices, _ = nnd.query(embedding, k=n_neighbors)
|
401
403
|
|
@@ -28,7 +28,7 @@ class CentroidSpace(PerturbationSpace):
|
|
28
28
|
layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
|
29
29
|
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
|
30
30
|
keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
|
31
|
-
each cell of one perturbation are kept.
|
31
|
+
each cell of one perturbation are kept.
|
32
32
|
|
33
33
|
Returns:
|
34
34
|
AnnData object with one observation per perturbation, storing the embedding data of the
|
@@ -129,7 +129,7 @@ class PseudobulkSpace(PerturbationSpace):
|
|
129
129
|
adata: Anndata object of size cells x genes
|
130
130
|
target_col: .obs column that stores the label of the perturbation applied to each cell.
|
131
131
|
groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
|
132
|
-
The summarized expression per perturbation (target_col) and group (groups_col) is computed.
|
132
|
+
The summarized expression per perturbation (target_col) and group (groups_col) is computed.
|
133
133
|
layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
|
134
134
|
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
|
135
135
|
**kwargs: Are passed to decoupler's get_pseudobulk.
|
@@ -254,7 +254,7 @@ class DBSCANSpace(ClusteringSpace):
|
|
254
254
|
adata: Anndata object of size cells x genes
|
255
255
|
layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
|
256
256
|
embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
|
257
|
-
cluster_key: name of the .obs column to store the cluster labels.
|
257
|
+
cluster_key: name of the .obs column to store the cluster labels.
|
258
258
|
copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
|
259
259
|
return_object: if True returns the clustering object
|
260
260
|
**kwargs: Are passed to sklearn's DBSCAN.
|
pertpy/tools/_scgen/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
from pertpy.tools._scgen._scgen import
|
1
|
+
from pertpy.tools._scgen._scgen import Scgen
|
@@ -28,7 +28,7 @@ class FlaxEncoder(nn.Module):
|
|
28
28
|
|
29
29
|
Args:
|
30
30
|
x: The input data matrix.
|
31
|
-
training: Whether
|
31
|
+
training: Whether to use running training average.
|
32
32
|
|
33
33
|
Returns:
|
34
34
|
Mean and variance.
|
@@ -69,12 +69,11 @@ class FlaxDecoder(nn.Module):
|
|
69
69
|
|
70
70
|
Args:
|
71
71
|
x: Input data.
|
72
|
-
training:
|
72
|
+
training: Whether to use running training average.
|
73
73
|
|
74
74
|
Returns:
|
75
75
|
Decoded data.
|
76
76
|
"""
|
77
|
-
|
78
77
|
training = nn.merge_param("training", self.training, training)
|
79
78
|
|
80
79
|
for _ in range(self.n_layers):
|
pertpy/tools/_scgen/_scgen.py
CHANGED
@@ -10,6 +10,7 @@ import scanpy as sc
|
|
10
10
|
from adjustText import adjust_text
|
11
11
|
from anndata import AnnData
|
12
12
|
from jax import Array
|
13
|
+
from lamin_utils import logger
|
13
14
|
from scipy import stats
|
14
15
|
from scvi import REGISTRY_KEYS
|
15
16
|
from scvi.data import AnnDataManager
|
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
|
|
26
27
|
font = {"family": "Arial", "size": 14}
|
27
28
|
|
28
29
|
|
29
|
-
class
|
30
|
+
class Scgen(JaxTrainingMixin, BaseModelClass):
|
30
31
|
"""Jax Implementation of scGen model for batch removal and perturbation prediction."""
|
31
32
|
|
32
33
|
def __init__(
|
@@ -49,7 +50,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
49
50
|
**model_kwargs,
|
50
51
|
)
|
51
52
|
self._model_summary_string = (
|
52
|
-
f"
|
53
|
+
f"Scgen Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
|
53
54
|
f"{dropout_rate}"
|
54
55
|
)
|
55
56
|
self.init_params_ = self._get_init_params(locals())
|
@@ -79,8 +80,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
79
80
|
Examples:
|
80
81
|
>>> import pertpy as pt
|
81
82
|
>>> data = pt.dt.kang_2018()
|
82
|
-
>>> pt.tl.
|
83
|
-
>>> model = pt.tl.
|
83
|
+
>>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
84
|
+
>>> model = pt.tl.Scgen(data)
|
84
85
|
>>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
|
85
86
|
>>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
|
86
87
|
"""
|
@@ -166,8 +167,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
166
167
|
Examples:
|
167
168
|
>>> import pertpy as pt
|
168
169
|
>>> data = pt.dt.kang_2018()
|
169
|
-
>>> pt.tl.
|
170
|
-
>>> model = pt.tl.
|
170
|
+
>>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
171
|
+
>>> model = pt.tl.Scgen(data)
|
171
172
|
>>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
|
172
173
|
>>> decoded_X = model.get_decoded_expression()
|
173
174
|
"""
|
@@ -200,8 +201,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
200
201
|
Examples:
|
201
202
|
>>> import pertpy as pt
|
202
203
|
>>> data = pt.dt.kang_2018()
|
203
|
-
>>> pt.tl.
|
204
|
-
>>> model = pt.tl.
|
204
|
+
>>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
205
|
+
>>> model = pt.tl.Scgen(data)
|
205
206
|
>>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
|
206
207
|
>>> corrected_adata = model.batch_removal()
|
207
208
|
"""
|
@@ -304,7 +305,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
304
305
|
Examples:
|
305
306
|
>>> import pertpy as pt
|
306
307
|
>>> data = pt.dt.kang_2018()
|
307
|
-
>>> pt.tl.
|
308
|
+
>>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
308
309
|
"""
|
309
310
|
setup_method_args = cls._get_setup_method_args(**locals())
|
310
311
|
anndata_fields = [
|
@@ -345,8 +346,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
345
346
|
Examples:
|
346
347
|
>>> import pertpy as pt
|
347
348
|
>>> data = pt.dt.kang_2018()
|
348
|
-
>>> pt.tl.
|
349
|
-
>>> model = pt.tl.
|
349
|
+
>>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
350
|
+
>>> model = pt.tl.Scgen(data)
|
350
351
|
>>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
|
351
352
|
>>> latent_X = model.get_latent_representation()
|
352
353
|
"""
|
@@ -403,19 +404,19 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
403
404
|
gene_list: list of gene names to be plotted.
|
404
405
|
show: if `True`, the plot is shown after saving it.
|
405
406
|
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
406
|
-
verbose: Specify if you want information to be printed while creating the plot
|
407
|
-
legend:
|
407
|
+
verbose: Specify if you want information to be printed while creating the plot.
|
408
|
+
legend: Whether to plot a legend.
|
408
409
|
title: Set if you want the plot to display a title.
|
409
|
-
x_coeff: Offset to print the R^2 value in x-direction
|
410
|
-
y_coeff: Offset to print the R^2 value in y-direction
|
411
|
-
fontsize: Fontsize used for text in the plot
|
410
|
+
x_coeff: Offset to print the R^2 value in x-direction.
|
411
|
+
y_coeff: Offset to print the R^2 value in y-direction.
|
412
|
+
fontsize: Fontsize used for text in the plot.
|
412
413
|
**kwargs:
|
413
414
|
|
414
415
|
Examples:
|
415
416
|
>>> import pertpy as pt
|
416
417
|
>>> data = pt.dt.kang_2018()
|
417
|
-
>>> pt.tl.
|
418
|
-
>>> scg = pt.tl.
|
418
|
+
>>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
|
419
|
+
>>> scg = pt.tl.Scgen(data)
|
419
420
|
>>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
|
420
421
|
>>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
|
421
422
|
>>> pred.obs['label'] = 'pred'
|
@@ -444,12 +445,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
444
445
|
y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
|
445
446
|
m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
|
446
447
|
if verbose:
|
447
|
-
|
448
|
+
logger.info("top_100 DEGs mean: ", r_value_diff**2)
|
448
449
|
x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
|
449
450
|
y = np.asarray(np.mean(stim.X, axis=0)).ravel()
|
450
451
|
m, b, r_value, p_value, std_err = stats.linregress(x, y)
|
451
452
|
if verbose:
|
452
|
-
|
453
|
+
logger.info("All genes mean: ", r_value**2)
|
453
454
|
df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
|
454
455
|
ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
|
455
456
|
ax.tick_params(labelsize=fontsize)
|
@@ -540,12 +541,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
540
541
|
gene_list: list of gene names to be plotted.
|
541
542
|
show: if `True`, the plot is shown after saving it.
|
542
543
|
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
|
543
|
-
legend:
|
544
|
+
legend: Whether to plot a legend.
|
544
545
|
title: Set if you want the plot to display a title.
|
545
|
-
verbose: Specify if you want information to be printed while creating the plot
|
546
|
-
x_coeff: Offset to print the R^2 value in x-direction
|
547
|
-
y_coeff: Offset to print the R^2 value in y-direction
|
548
|
-
fontsize: Fontsize used for text in the plot
|
546
|
+
verbose: Specify if you want information to be printed while creating the plot.
|
547
|
+
x_coeff: Offset to print the R^2 value in x-direction.
|
548
|
+
y_coeff: Offset to print the R^2 value in y-direction.
|
549
|
+
fontsize: Fontsize used for text in the plot.
|
549
550
|
"""
|
550
551
|
import seaborn as sns
|
551
552
|
|
@@ -566,14 +567,14 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
566
567
|
y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
|
567
568
|
m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
|
568
569
|
if verbose:
|
569
|
-
|
570
|
+
logger.info("Top 100 DEGs var: ", r_value_diff**2)
|
570
571
|
if "y1" in axis_keys.keys():
|
571
572
|
real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
|
572
573
|
x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
|
573
574
|
y = np.asarray(np.var(stim.X, axis=0)).ravel()
|
574
575
|
m, b, r_value, p_value, std_err = stats.linregress(x, y)
|
575
576
|
if verbose:
|
576
|
-
|
577
|
+
logger.info("All genes var: ", r_value**2)
|
577
578
|
df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
|
578
579
|
ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
|
579
580
|
ax.tick_params(labelsize=fontsize)
|
@@ -637,7 +638,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
637
638
|
|
638
639
|
def plot_binary_classifier(
|
639
640
|
self,
|
640
|
-
scgen:
|
641
|
+
scgen: Scgen,
|
641
642
|
adata: AnnData | None,
|
642
643
|
delta: np.ndarray,
|
643
644
|
ctrl_key: str,
|
@@ -699,3 +700,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
|
|
699
700
|
if not (show or save):
|
700
701
|
return ax
|
701
702
|
return None
|
703
|
+
|
704
|
+
|
705
|
+
# compatibility
|
706
|
+
SCGEN = Scgen
|
pertpy/tools/_scgen/_utils.py
CHANGED
@@ -27,7 +27,7 @@ def extractor(
|
|
27
27
|
Example:
|
28
28
|
.. code-block:: python
|
29
29
|
|
30
|
-
import
|
30
|
+
import Scgen
|
31
31
|
import anndata
|
32
32
|
|
33
33
|
train_data = anndata.read("./data/train.h5ad")
|
@@ -58,7 +58,7 @@ def balancer(
|
|
58
58
|
Example:
|
59
59
|
.. code-block:: python
|
60
60
|
|
61
|
-
import
|
61
|
+
import Scgen
|
62
62
|
import anndata
|
63
63
|
|
64
64
|
train_data = anndata.read("./train_kang.h5ad")
|
@@ -1,11 +1,11 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: pertpy
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.8.0
|
4
4
|
Summary: Perturbation Analysis in the scverse ecosystem.
|
5
5
|
Project-URL: Documentation, https://pertpy.readthedocs.io
|
6
|
-
Project-URL: Source, https://github.com/
|
7
|
-
Project-URL: Home-page, https://github.com/
|
8
|
-
Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu,
|
6
|
+
Project-URL: Source, https://github.com/scverse/pertpy
|
7
|
+
Project-URL: Home-page, https://github.com/scverse/pertpy
|
8
|
+
Author: Lukas Heumos, Yuge Ji, Lilly May, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Tessa Green, Stefan Peidli, Antonia Schumacher, Gregor Sturm
|
9
9
|
Maintainer-email: Lukas Heumos <lukas.heumos@posteo.net>
|
10
10
|
License: MIT License
|
11
11
|
|
@@ -45,11 +45,10 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
45
45
|
Classifier: Topic :: Scientific/Engineering :: Visualization
|
46
46
|
Requires-Python: >=3.10
|
47
47
|
Requires-Dist: adjusttext
|
48
|
-
Requires-Dist: arviz
|
49
48
|
Requires-Dist: blitzgsea
|
50
49
|
Requires-Dist: decoupler
|
50
|
+
Requires-Dist: lamin-utils
|
51
51
|
Requires-Dist: muon
|
52
|
-
Requires-Dist: numpyro
|
53
52
|
Requires-Dist: openpyxl
|
54
53
|
Requires-Dist: ott-jax
|
55
54
|
Requires-Dist: pubchempy
|
@@ -61,10 +60,14 @@ Requires-Dist: scikit-misc
|
|
61
60
|
Requires-Dist: scipy
|
62
61
|
Requires-Dist: scvi-tools
|
63
62
|
Requires-Dist: sparsecca
|
64
|
-
Requires-Dist: toytree
|
65
63
|
Provides-Extra: coda
|
64
|
+
Requires-Dist: arviz; extra == 'coda'
|
66
65
|
Requires-Dist: ete3; extra == 'coda'
|
67
66
|
Requires-Dist: pyqt5; extra == 'coda'
|
67
|
+
Requires-Dist: toytree; extra == 'coda'
|
68
|
+
Provides-Extra: de
|
69
|
+
Requires-Dist: formulaic; extra == 'de'
|
70
|
+
Requires-Dist: pydeseq2; extra == 'de'
|
68
71
|
Provides-Extra: dev
|
69
72
|
Requires-Dist: pre-commit; extra == 'dev'
|
70
73
|
Provides-Extra: doc
|
@@ -94,18 +97,18 @@ Requires-Dist: pytest; extra == 'test'
|
|
94
97
|
Description-Content-Type: text/markdown
|
95
98
|
|
96
99
|
[](https://github.com/psf/black)
|
97
|
-
[](https://github.com/scverse/pertpy/actions/workflows/build.yml)
|
101
|
+
[](https://codecov.io/gh/scverse/pertpy)
|
102
|
+
[](https://opensource.org/licenses/Apache2.0)
|
100
103
|
[](https://pypi.org/project/pertpy/)
|
101
104
|
[](https://pypi.org/project/pertpy)
|
102
105
|
[](https://pertpy.readthedocs.io/)
|
103
|
-
[](https://github.com/scverse/pertpy/actions/workflows/test.yml)
|
104
107
|
[](https://github.com/pre-commit/pre-commit)
|
105
108
|
|
106
109
|
# pertpy
|
107
110
|
|
108
|
-

|
109
112
|
|
110
113
|
## Documentation
|
111
114
|
|
@@ -119,12 +122,18 @@ You can install _pertpy_ via [pip] from [PyPI]:
|
|
119
122
|
pip install pertpy
|
120
123
|
```
|
121
124
|
|
122
|
-
if you want to use scCODA please install
|
125
|
+
If you want to use scCODA or tascCODA, please install pertpy as follows:
|
123
126
|
|
124
127
|
```console
|
125
128
|
pip install pertpy[coda]
|
126
129
|
```
|
127
130
|
|
131
|
+
If you want to use the differential gene expression interface, please install pertpy by running:
|
132
|
+
|
133
|
+
```console
|
134
|
+
pip install pertpy[de]
|
135
|
+
```
|
136
|
+
|
128
137
|
[pip]: https://pip.pypa.io/
|
129
138
|
[pypi]: https://pypi.org/
|
130
139
|
[usage]: https://pertpy.readthedocs.io/en/latest/usage/usage.html
|