pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +2 -1
- pertpy/data/__init__.py +61 -0
- pertpy/data/_dataloader.py +27 -23
- pertpy/data/_datasets.py +58 -0
- pertpy/metadata/__init__.py +2 -0
- pertpy/metadata/_cell_line.py +39 -70
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_drug.py +2 -6
- pertpy/metadata/_look_up.py +38 -51
- pertpy/metadata/_metadata.py +7 -10
- pertpy/metadata/_moa.py +2 -6
- pertpy/plot/__init__.py +0 -5
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +2 -3
- pertpy/tools/__init__.py +42 -4
- pertpy/tools/_augur.py +14 -15
- pertpy/tools/_cinemaot.py +2 -2
- pertpy/tools/_coda/_base_coda.py +118 -142
- pertpy/tools/_coda/_sccoda.py +16 -15
- pertpy/tools/_coda/_tasccoda.py +21 -22
- pertpy/tools/_dialogue.py +18 -23
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +21 -16
- pertpy/tools/_distances/_distances.py +406 -70
- pertpy/tools/_enrichment.py +10 -15
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +76 -53
- pertpy/tools/_mixscape.py +15 -11
- pertpy/tools/_perturbation_space/_clustering.py +5 -2
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
- pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
- pertpy/tools/_perturbation_space/_simple.py +3 -3
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +33 -28
- pertpy/tools/_scgen/_utils.py +2 -2
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -171
- pertpy/plot/_coda.py +0 -601
- pertpy/plot/_guide_rna.py +0 -64
- pertpy/plot/_milopy.py +0 -209
- pertpy/plot/_mixscape.py +0 -355
- pertpy/tools/_differential_gene_expression.py +0 -325
- pertpy-0.7.0.dist-info/RECORD +0 -53
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_comparison.py
ADDED
@@ -0,0 +1,112 @@
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pynndescent
+from scipy.sparse import issparse
+from scipy.sparse import vstack as sp_vstack
+from sklearn.base import ClassifierMixin
+from sklearn.linear_model import LogisticRegression
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class PerturbationComparison:
+    """Comparison between real and simulated perturbations."""
+
+    def compare_classification(
+        self,
+        real: np.ndarray,
+        simulated: np.ndarray,
+        control: np.ndarray,
+        clf: ClassifierMixin | None = None,
+    ) -> float:
+        """Compare classification accuracy between real and simulated perturbations.
+
+        Trains a classifier on the real perturbation data + the control data and reports a normalized
+        classification accuracy on the simulated perturbation.
+
+        Args:
+            real: Real perturbed data.
+            simulated: Simulated perturbed data.
+            control: Control data.
+            clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
+        """
+        assert real.shape[1] == simulated.shape[1] == control.shape[1]
+        if clf is None:
+            clf = LogisticRegression()
+        n_x = real.shape[0]
+        data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control))
+        labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")])
+
+        clf.fit(data, labels)
+        norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x])
+        norm_score = min(1.0, norm_score)
+
+        return norm_score
+
+    def compare_knn(
+        self,
+        real: np.ndarray,
+        simulated: np.ndarray,
+        control: np.ndarray | None = None,
+        use_simulated_for_knn: bool = False,
+        n_neighbors: int = 20,
+        random_state: int = 0,
+        n_jobs: int = 1,
+    ) -> dict[str, float]:
+        """Calculate proportions of real perturbed and control data points for simulated data.
+
+        Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
+        data points for simulated data. If `control` is not provided, builds the knn graph from
+        real perturbed + simulated perturbed.
+
+        Args:
+            real: Real perturbed data.
+            simulated: Simulated perturbed data.
+            control: Control data.
+            use_simulated_for_knn: Include simulated perturbed data (`simulated`) into the knn graph. Only valid when
+                `control` is provided.
+            n_neighbors: Number of neighbors to use in k-neighbor graph.
+            random_state: Random state used for k-neighbor graph construction.
+            n_jobs: Number of cores to use.
+        """
+        assert real.shape[1] == simulated.shape[1]
+        if control is not None:
+            assert real.shape[1] == control.shape[1]
+
+        n_y = simulated.shape[0]
+
+        if control is None:
+            index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real))
+        else:
+            datas = (simulated, real, control) if use_simulated_for_knn else (real, control)
+            index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas)
+
+        y_in_index = use_simulated_for_knn or control is None
+        c_in_index = control is not None
+        label_groups = ["comp"]
+        labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp")
+        if y_in_index:
+            labels[:n_y] = "siml"
+            label_groups.append("siml")
+        if c_in_index:
+            labels[-control.shape[0] :] = "ctrl"
+            label_groups.append("ctrl")
+
+        index = pynndescent.NNDescent(
+            index_data,
+            n_neighbors=max(50, n_neighbors),
+            random_state=random_state,
+            n_jobs=n_jobs,
+        )
+        indices = index.query(simulated, k=n_neighbors)[0]
+
+        uq, uq_counts = np.unique(labels[indices], return_counts=True)
+        uq_counts_norm = uq_counts / uq_counts.sum()
+        counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
+        for group, count_norm in zip(uq, uq_counts_norm, strict=False):
+            counts[group] = count_norm
+
+        return counts
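For orientation, a minimal usage sketch of the new class (synthetic arrays stand in for real and predicted perturbation embeddings; the import path follows the module location above, and any re-export in `pertpy.tools` is not shown in this diff):

```python
import numpy as np

# Import path assumed from the module location shown above.
from pertpy.tools._perturbation_space._comparison import PerturbationComparison

rng = np.random.default_rng(0)
real = rng.normal(1.0, size=(100, 20))       # real perturbed cells
simulated = rng.normal(0.9, size=(100, 20))  # e.g. model-predicted perturbed cells
control = rng.normal(0.0, size=(100, 20))    # control cells

comparison = PerturbationComparison()

# Classification accuracy on the simulated cells, normalized by the accuracy on
# the real cells and capped at 1.0: values near 1.0 mean simulated cells are
# classified as "perturbed" about as reliably as real ones.
score = comparison.compare_classification(real, simulated, control)

# Fractions of real-perturbed ("comp") and control ("ctrl") points among the
# k nearest neighbors of the simulated cells.
proportions = comparison.compare_knn(real, simulated, control, n_neighbors=20)
print(score, proportions)
```

Both metrics work on dense or sparse matrices, since the implementation switches between `np.vstack` and `scipy.sparse.vstack` as needed.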
pertpy/tools/_perturbation_space/_discriminator_classifiers.py
CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import warnings
-from typing import TYPE_CHECKING, Literal
 
 import anndata
 import numpy as np
@@ -42,12 +41,12 @@ class LRClassifierSpace(PerturbationSpace):
 
        Args:
            adata: AnnData object of size cells x genes
-           target_col: .obs column that stores the perturbations.
-           layer_key: Layer in adata to use.
+           target_col: .obs column that stores the perturbations.
+           layer_key: Layer in adata to use.
            embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
-               Can only be specified if layer_key is None.
-           test_split_size: Fraction of data to put in the test set.
-           max_iter: Maximum number of iterations taken for the solvers to converge.
+               Can only be specified if layer_key is None.
+           test_split_size: Fraction of data to put in the test set.
+           max_iter: Maximum number of iterations taken for the solvers to converge.
 
        Returns:
            AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
@@ -163,24 +162,23 @@ class MLPClassifierSpace(PerturbationSpace):
 
        Args:
            adata: AnnData object of size cells x genes
-           target_col: .obs column that stores the perturbations.
-           layer_key: Layer in adata to use.
+           target_col: .obs column that stores the perturbations.
+           layer_key: Layer in adata to use.
            hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
                will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
-
-
-
-           batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
+           dropout: Amount of dropout applied, constant for all layers.
+           batch_norm: Whether to apply batch normalization.
+           batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
            test_split_size: Fraction of data to put in the test set. Defaults to 0.2.
            validation_split_size: Fraction of data to put in the validation set of the resultant train set.
                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
-               will be used for validation.
-           max_epochs: Maximum number of epochs for training.
+               will be used for validation.
+           max_epochs: Maximum number of epochs for training.
            val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
                Note that this affects early stopping, as the model will be stopped if the validation performance does not
-               improve for patience epochs.
+               improve for patience epochs.
            patience: Number of validation performance checks without improvement, after which the early stopping flag
-               is activated and training is therefore stopped.
+               is activated and training is therefore stopped.
 
        Returns:
            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
@@ -325,10 +323,10 @@ class MLP(torch.nn.Module):
        """
        Args:
            sizes: size of layers.
-           dropout: Dropout probability.
-           batch_norm: specifies if batch norm should be applied.
-           layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
-           last_layer_act: activation function of last layer.
+           dropout: Dropout probability.
+           batch_norm: specifies if batch norm should be applied.
+           layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
+           last_layer_act: activation function of last layer.
        """
        super().__init__()
        layers = []
@@ -392,8 +390,8 @@ class PLDataset(Dataset):
        """
        Args:
            adata: AnnData object with observations and labels.
-           target_col: key with the perturbation labels numerically encoded.
-           label_col: key with the perturbation labels.
+           target_col: key with the perturbation labels numerically encoded.
+           label_col: key with the perturbation labels.
            layer_key: key of the layer to be used as data, otherwise .X
        """
 
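As a quick orientation for the classifier-based spaces touched above, a hedged sketch follows (the `compute()` entry point is assumed from pertpy's perturbation-space convention, its arguments are taken from the docstrings in this diff, and the toy AnnData is a placeholder):

```python
import numpy as np
import pandas as pd
from anndata import AnnData

# Assumed import path (module location per this diff); pertpy likely also
# re-exports the class via pertpy.tools.
from pertpy.tools._perturbation_space._discriminator_classifiers import LRClassifierSpace

rng = np.random.default_rng(0)
adata = AnnData(rng.normal(size=(120, 30)).astype(np.float32))
adata.obs["perturbation"] = pd.Categorical(["control", "pertA", "pertB"] * 40)

lr_space = LRClassifierSpace()
# Per the docstring above: trains a logistic regression on cells x genes and
# returns an AnnData with the coefficients in .X and .obs["perturbations"].
pert_embedding = lr_space.compute(adata, target_col="perturbation")
```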
pertpy/tools/_perturbation_space/_perturbation_space.py
CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from pynndescent import NNDescent
+from lamin_utils import logger
 from rich import print
 
 if TYPE_CHECKING:
@@ -40,13 +40,13 @@ class PerturbationSpace:
 
        Args:
            adata: Anndata object of size cells x genes.
-           target_col: .obs column name that stores the label of the perturbation applied to each cell.
-           group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
-           reference_key: The key of the control values.
-           layer_key: Key of the AnnData layer to use for computation.
-           new_layer_key: the results are stored in the given layer.
-           embedding_key: `obsm` key of the AnnData embedding to use for computation.
-           new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
+           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+           group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
+           reference_key: The key of the control values.
+           layer_key: Key of the AnnData layer to use for computation.
+           new_layer_key: the results are stored in the given layer.
+           embedding_key: `obsm` key of the AnnData embedding to use for computation.
+           new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
            all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
            copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
 
@@ -150,14 +150,14 @@ class PerturbationSpace:
        ensure_consistency: bool = False,
        target_col: str = "perturbation",
    ) -> tuple[AnnData, AnnData] | AnnData:
-       """Add perturbations linearly. Assumes input of size n_perts x dimensionality
+       """Add perturbations linearly. Assumes input of size n_perts x dimensionality.
 
        Args:
            adata: Anndata object of size n_perts x dim.
            perturbations: Perturbations to add.
-           reference_key: perturbation source from which the perturbation summation starts.
+           reference_key: perturbation source from which the perturbation summation starts.
            ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+           target_col: .obs column name that stores the label of the perturbation applied to each cell.
 
        Returns:
            Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
@@ -182,8 +182,8 @@ class PerturbationSpace:
            new_pert_name += perturbation + "+"
 
        if not ensure_consistency:
-           print(
-               "Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
+           logger.warning(
+               "Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
                "Subtract control perturbation needed for consistency of space in all data representations. \n"
                "Run with ensure_consistency=True"
            )
@@ -264,9 +264,9 @@ class PerturbationSpace:
        Args:
            adata: Anndata object of size n_perts x dim.
            perturbations: Perturbations to subtract.
-           reference_key: Perturbation source from which the perturbation subtraction starts.
+           reference_key: Perturbation source from which the perturbation subtraction starts.
            ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+           target_col: .obs column name that stores the label of the perturbation applied to each cell.
 
        Returns:
            Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
@@ -291,8 +291,8 @@ class PerturbationSpace:
            new_pert_name += perturbation + "-"
 
        if not ensure_consistency:
-           print(
-               "Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
+           logger.warning(
+               "Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
                "Subtract control perturbation needed for consistency of space in all data representations.\n"
                "Run with ensure_consistency=True"
            )
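Since the warning concerns user-facing behavior of `add()`/`subtract()`, here is a hedged sketch of the changed path (argument names from the docstrings above; the toy AnnData is a placeholder, and the default `reference_key` of "control" is an assumption):

```python
import numpy as np
import pandas as pd
from anndata import AnnData

# Base-class import path as in this diff; the space subclasses inherit add()/subtract().
from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace

# One observation per perturbation, as produced by a perturbation space.
psadata = AnnData(np.random.default_rng(0).normal(size=(3, 10)))
psadata.obs["perturbation"] = pd.Categorical(["control", "pertA", "pertB"])

space = PerturbationSpace()
# Without ensure_consistency=True, 0.8.0 emits the logger.warning shown above
# (previously a rich print), since perturbation - perturbation != control may hold.
combined = space.add(psadata, perturbations=["pertA", "pertB"])
# With ensure_consistency=True, differential expression is run on all data
# matrices and, per the return annotation, a tuple of AnnData objects is returned.
```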
@@ -372,10 +372,10 @@ class PerturbationSpace:
 
        Args:
            adata: The AnnData object containing single-cell data.
-           column: The column name in AnnData object to perform imputation on.
-           target_val: The target value to impute.
-           n_neighbors: Number of neighbors to use for imputation.
-           use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
+           column: The column name in AnnData object to perform imputation on.
+           target_val: The target value to impute.
+           n_neighbors: Number of neighbors to use for imputation.
+           use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
 
        Examples:
            >>> import pertpy as pt
@@ -396,6 +396,8 @@ class PerturbationSpace:
 
        embedding = adata.obsm[use_rep]
 
+       from pynndescent import NNDescent
+
        nnd = NNDescent(embedding, n_neighbors=n_neighbors)
        indices, _ = nnd.query(embedding, k=n_neighbors)
 
pertpy/tools/_perturbation_space/_simple.py
CHANGED
@@ -28,7 +28,7 @@ class CentroidSpace(PerturbationSpace):
            layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
            embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
            keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
-               each cell of one perturbation are kept.
+               each cell of one perturbation are kept.
 
        Returns:
            AnnData object with one observation per perturbation, storing the embedding data of the
@@ -129,7 +129,7 @@ class PseudobulkSpace(PerturbationSpace):
            adata: Anndata object of size cells x genes
            target_col: .obs column that stores the label of the perturbation applied to each cell.
            groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
-               The summarized expression per perturbation (target_col) and group (groups_col) is computed.
+               The summarized expression per perturbation (target_col) and group (groups_col) is computed.
            layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
            embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
            **kwargs: Are passed to decoupler's get_pseudobulk.
@@ -254,7 +254,7 @@ class DBSCANSpace(ClusteringSpace):
            adata: Anndata object of size cells x genes
            layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
            embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
-           cluster_key: name of the .obs column to store the cluster labels.
+           cluster_key: name of the .obs column to store the cluster labels.
            copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
            return_object: if True returns the clustering object
            **kwargs: Are passed to sklearn's DBSCAN.
pertpy/tools/_scgen/__init__.py
CHANGED
@@ -1 +1 @@
-from pertpy.tools._scgen._scgen import SCGEN
+from pertpy.tools._scgen._scgen import Scgen
pertpy/tools/_scgen/_base_components.py
CHANGED
@@ -28,7 +28,7 @@ class FlaxEncoder(nn.Module):
 
        Args:
            x: The input data matrix.
-           training: Whether
+           training: Whether to use running training average.
 
        Returns:
            Mean and variance.
@@ -69,12 +69,11 @@ class FlaxDecoder(nn.Module):
 
        Args:
            x: Input data.
-           training:
+           training: Whether to use running training average.
 
        Returns:
            Decoded data.
        """
-
        training = nn.merge_param("training", self.training, training)
 
        for _ in range(self.n_layers):
pertpy/tools/_scgen/_scgen.py
CHANGED
@@ -10,6 +10,7 @@ import scanpy as sc
 from adjustText import adjust_text
 from anndata import AnnData
 from jax import Array
+from lamin_utils import logger
 from scipy import stats
 from scvi import REGISTRY_KEYS
 from scvi.data import AnnDataManager
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
 font = {"family": "Arial", "size": 14}
 
 
-class SCGEN(JaxTrainingMixin, BaseModelClass):
+class Scgen(JaxTrainingMixin, BaseModelClass):
     """Jax Implementation of scGen model for batch removal and perturbation prediction."""
 
     def __init__(
@@ -49,7 +50,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            **model_kwargs,
        )
        self._model_summary_string = (
-           f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
+           f"Scgen Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}"
        )
        self.init_params_ = self._get_init_params(locals())
@@ -79,8 +80,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        """
@@ -166,8 +167,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
@@ -200,8 +201,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
@@ -304,7 +305,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
@@ -345,8 +346,8 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> model = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> model = pt.tl.Scgen(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
@@ -403,19 +404,19 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            gene_list: list of gene names to be plotted.
            show: if `True`: will show the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
-           verbose: Specify if you want information to be printed while creating the plot
-           legend:
+           verbose: Specify if you want information to be printed while creating the plot.
+           legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
-           x_coeff: Offset to print the R^2 value in x-direction
-           y_coeff: Offset to print the R^2 value in y-direction
-           fontsize: Fontsize used for text in the plot
+           x_coeff: Offset to print the R^2 value in x-direction.
+           y_coeff: Offset to print the R^2 value in y-direction.
+           fontsize: Fontsize used for text in the plot.
            **kwargs:
 
        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
-           >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
-           >>> scg = pt.tl.SCGEN(data)
+           >>> pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
+           >>> scg = pt.tl.Scgen(data)
            >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
@@ -444,12 +445,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
        m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
        if verbose:
-           print("top_100 DEGs mean: ", r_value_diff**2)
+           logger.info("top_100 DEGs mean: ", r_value_diff**2)
        x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.mean(stim.X, axis=0)).ravel()
        m, b, r_value, p_value, std_err = stats.linregress(x, y)
        if verbose:
-           print("All genes mean: ", r_value**2)
+           logger.info("All genes mean: ", r_value**2)
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
@@ -540,12 +541,12 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
            gene_list: list of gene names to be plotted.
            show: if `True`: will show the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
-           legend:
+           legend: Whether to plot a legend.
            title: Set if you want the plot to display a title.
-           verbose: Specify if you want information to be printed while creating the plot
-           x_coeff: Offset to print the R^2 value in x-direction
-           y_coeff: Offset to print the R^2 value in y-direction
-           fontsize: Fontsize used for text in the plot
+           verbose: Specify if you want information to be printed while creating the plot.
+           x_coeff: Offset to print the R^2 value in x-direction.
+           y_coeff: Offset to print the R^2 value in y-direction.
+           fontsize: Fontsize used for text in the plot.
        """
        import seaborn as sns
 
@@ -566,14 +567,14 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
        m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
        if verbose:
-           print("Top 100 DEGs var: ", r_value_diff**2)
+           logger.info("Top 100 DEGs var: ", r_value_diff**2)
        if "y1" in axis_keys.keys():
            real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
        x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.var(stim.X, axis=0)).ravel()
        m, b, r_value, p_value, std_err = stats.linregress(x, y)
        if verbose:
-           print("All genes var: ", r_value**2)
+           logger.info("All genes var: ", r_value**2)
        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
@@ -637,7 +638,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
 
    def plot_binary_classifier(
        self,
-       scgen: SCGEN,
+       scgen: Scgen,
        adata: AnnData | None,
        delta: np.ndarray,
        ctrl_key: str,
@@ -699,3 +700,7 @@ class SCGEN(JaxTrainingMixin, BaseModelClass):
        if not (show or save):
            return ax
        return None
+
+
+# compatibility
+SCGEN = Scgen
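The SCGEN → Scgen rename is user-visible; a minimal sketch of the 0.8.0 usage (calls taken verbatim from the doctests above; the module-level alias `SCGEN = Scgen` keeps old imports working, though whether `pt.tl.SCGEN` is still re-exported is not shown in this diff):

```python
import pertpy as pt

data = pt.dt.kang_2018()

# 0.8.0 spelling; 0.7.0 used pt.tl.SCGEN for the same calls.
pt.tl.Scgen.setup_anndata(data, batch_key="label", labels_key="cell_type")
model = pt.tl.Scgen(data)
model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
```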
pertpy/tools/_scgen/_utils.py
CHANGED
@@ -27,7 +27,7 @@ def extractor(
     Example:
         .. code-block:: python
 
-            import SCGEN
+            import Scgen
            import anndata
 
            train_data = anndata.read("./data/train.h5ad")
@@ -58,7 +58,7 @@ def balancer(
     Example:
         .. code-block:: python
 
-            import SCGEN
+            import Scgen
            import anndata
 
            train_data = anndata.read("./train_kang.h5ad")
{pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA
CHANGED
@@ -1,11 +1,11 @@
 Metadata-Version: 2.3
 Name: pertpy
-Version: 0.7.0
+Version: 0.8.0
 Summary: Perturbation Analysis in the scverse ecosystem.
 Project-URL: Documentation, https://pertpy.readthedocs.io
-Project-URL: Source, https://github.com/
-Project-URL: Home-page, https://github.com/
-Author: Lukas Heumos, Yuge Ji, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu,
+Project-URL: Source, https://github.com/scverse/pertpy
+Project-URL: Home-page, https://github.com/scverse/pertpy
+Author: Lukas Heumos, Yuge Ji, Lilly May, Alejandro Tejada, Johannes Köster, Emma Dann, Xinyue Zhang, Xichen Wu, Tessa Green, Stefan Peidli, Antonia Schumacher, Gregor Sturm
 Maintainer-email: Lukas Heumos <lukas.heumos@posteo.net>
 License: MIT License
 
@@ -45,11 +45,10 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
 Classifier: Topic :: Scientific/Engineering :: Visualization
 Requires-Python: >=3.10
 Requires-Dist: adjusttext
-Requires-Dist: arviz
 Requires-Dist: blitzgsea
 Requires-Dist: decoupler
+Requires-Dist: lamin-utils
 Requires-Dist: muon
-Requires-Dist: numpyro
 Requires-Dist: openpyxl
 Requires-Dist: ott-jax
 Requires-Dist: pubchempy
@@ -61,10 +60,14 @@ Requires-Dist: scikit-misc
 Requires-Dist: scipy
 Requires-Dist: scvi-tools
 Requires-Dist: sparsecca
-Requires-Dist: toytree
 Provides-Extra: coda
+Requires-Dist: arviz; extra == 'coda'
 Requires-Dist: ete3; extra == 'coda'
 Requires-Dist: pyqt5; extra == 'coda'
+Requires-Dist: toytree; extra == 'coda'
+Provides-Extra: de
+Requires-Dist: formulaic; extra == 'de'
+Requires-Dist: pydeseq2; extra == 'de'
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == 'dev'
 Provides-Extra: doc
@@ -94,18 +97,18 @@ Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
 
 [](https://github.com/psf/black)
-[](https://github.com/scverse/pertpy/actions/workflows/build.yml)
+[](https://codecov.io/gh/scverse/pertpy)
+[](https://opensource.org/licenses/Apache2.0)
 [](https://pypi.org/project/pertpy/)
 [](https://pypi.org/project/pertpy)
 [](https://pertpy.readthedocs.io/)
-[](https://github.com/scverse/pertpy/actions/workflows/test.yml)
 [](https://github.com/pre-commit/pre-commit)
 
 # pertpy
 
-
 
 ## Documentation
 
@@ -119,12 +122,18 @@ You can install _pertpy_ via [pip] from [PyPI]:
 pip install pertpy
 ```
 
-if you want to use scCODA please install
+If you want to use scCODA or tascCODA, please install pertpy as follows:
 
 ```console
 pip install pertpy[coda]
 ```
 
+If you want to use the differential gene expression interface, please install pertpy by running:
+
+```console
+pip install pertpy[de]
+```
+
 [pip]: https://pip.pypa.io/
 [pypi]: https://pypi.org/
 [usage]: https://pertpy.readthedocs.io/en/latest/usage/usage.html