pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_perturbation_space.py
CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 import numpy as np
 import pandas as pd
 from anndata import AnnData
+from lamin_utils import logger
 from rich import print

 if TYPE_CHECKING:
@@ -15,7 +16,7 @@ class PerturbationSpace:
     """Implements various ways of interacting with PerturbationSpaces.

     We differentiate between a cell space and a perturbation space.
-    Visually speaking, in cell spaces single
+    Visually speaking, in cell spaces single data points in an embeddings summarize a cell,
     whereas in a perturbation space, data points summarize whole perturbations.
     """

@@ -25,7 +26,8 @@ class PerturbationSpace:
     def compute_control_diff(  # type: ignore
         self,
         adata: AnnData,
-        target_col: str = "
+        target_col: str = "perturbation",
+        group_col: str = None,
         reference_key: str = "control",
         layer_key: str = None,
         new_layer_key: str = "control_diff",
@@ -33,26 +35,31 @@ class PerturbationSpace:
         new_embedding_key: str = "control_diff",
         all_data: bool = False,
         copy: bool = False,
-    ):
+    ) -> AnnData:
         """Subtract mean of the control from the perturbation.

         Args:
             adata: Anndata object of size cells x genes.
-            target_col: .obs column name that stores the label of the perturbation applied to each cell.
-
-
-
-
-
+            target_col: .obs column name that stores the label of the perturbation applied to each cell.
+            group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups.
+            reference_key: The key of the control values.
+            layer_key: Key of the AnnData layer to use for computation.
+            new_layer_key: the results are stored in the given layer.
+            embedding_key: `obsm` key of the AnnData embedding to use for computation.
+            new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
             all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
             copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.

+        Returns:
+            Updated AnnData object.
+
         Examples:
            Example usage with PseudobulkSpace:
+
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
-           >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key=
+           >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key="NT")
         """
         if reference_key not in adata.obs[target_col].unique():
             raise ValueError(
@@ -69,48 +76,67 @@ class PerturbationSpace:
             adata = adata.copy()

         control_mask = adata.obs[target_col] == reference_key
-
+        group_masks = (
+            [(adata.obs[group_col] == sample) for sample in adata.obs[group_col].unique()]
+            if group_col
+            else [[True] * adata.n_obs]
+        )

         if layer_key:
-
-
-
-
-
-
+            adata.layers[new_layer_key] = np.zeros((adata.n_obs, adata.n_vars))
+            for mask in group_masks:
+                num_control = (control_mask & mask).sum()
+                if num_control == 1:
+                    control_expression = adata.layers[layer_key][(control_mask & mask), :]
+                elif num_control > 1:
+                    control_expression = np.mean(adata.layers[layer_key][(control_mask & mask), :], axis=0)
+                else:
+                    control_expression = np.zeros((1, adata.n_vars))
+                adata.layers[new_layer_key][mask, :] = adata.layers[layer_key][mask, :] - control_expression

         if embedding_key:
-
-
-
-
-
-
+            adata.obsm[new_embedding_key] = np.zeros(adata.obsm[embedding_key].shape)
+            for mask in group_masks:
+                num_control = (control_mask & mask).sum()
+                if num_control == 1:
+                    control_expression = adata.obsm[embedding_key][(control_mask & mask), :]
+                elif num_control > 1:
+                    control_expression = np.mean(adata.obsm[embedding_key][(control_mask & mask), :], axis=0)
+                else:
+                    control_expression = np.zeros((1, adata.n_vars))
+                adata.obsm[new_embedding_key][mask, :] = adata.obsm[embedding_key][mask, :] - control_expression

         if (not layer_key and not embedding_key) or all_data:
-
-
-
-
-
-
+            adata_x = np.zeros((adata.n_obs, adata.n_vars))
+            for mask in group_masks:
+                num_control = (control_mask & mask).sum()
+                if num_control == 1:
+                    control_expression = adata.X[(control_mask & mask), :]
+                elif num_control > 1:
+                    control_expression = np.mean(adata.X[(control_mask & mask), :], axis=0)
+                else:
+                    control_expression = np.zeros((1, adata.n_vars))
+                adata_x[mask, :] = adata.X[mask, :] - control_expression
+            adata.X = adata_x

         if all_data:
             layers_keys = list(adata.layers.keys())
             for local_layer_key in layers_keys:
                 if local_layer_key != layer_key and local_layer_key != new_layer_key:
-
-
-
-
+                    adata.layers[local_layer_key + "_control_diff"] = np.zeros((adata.n_obs, adata.n_vars))
+                    for mask in group_masks:
+                        adata.layers[local_layer_key + "_control_diff"][mask, :] = adata.layers[local_layer_key][
+                            mask, :
+                        ] - np.mean(adata.layers[local_layer_key][(control_mask & mask), :], axis=0)

             embedding_keys = list(adata.obsm_keys())
             for local_embedding_key in embedding_keys:
                 if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
-
-
-
-
+                    adata.obsm[local_embedding_key + "_control_diff"] = np.zeros(adata.obsm[local_embedding_key].shape)
+                    for mask in group_masks:
+                        adata.obsm[local_embedding_key + "_control_diff"][mask, :] = adata.obsm[local_embedding_key][
+                            mask, :
+                        ] - np.mean(adata.obsm[local_embedding_key][(control_mask & mask), :], axis=0)

         self.control_diff_computed = True

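The hunks above add a `group_col` argument to `compute_control_diff`, so the control mean is subtracted within each group rather than globally. A minimal usage sketch based on the signature in this diff; the `"replicate"` grouping column is illustrative and assumes such a column exists in the dataset:

```python
import pertpy as pt

mdata = pt.dt.papalexi_2021()
ps = pt.tl.PseudobulkSpace()

# Subtract the mean of the "NT" control cells separately within each
# (assumed) "replicate" group, writing the result back to .X of the copy.
diff_adata = ps.compute_control_diff(
    mdata["rna"],
    target_col="gene_target",
    group_col="replicate",
    reference_key="NT",
    copy=True,
)
```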
@@ -122,24 +148,30 @@ class PerturbationSpace:
         perturbations: Iterable[str],
         reference_key: str = "control",
         ensure_consistency: bool = False,
-        target_col: str = "
-    ):
-        """Add perturbations linearly. Assumes input of size n_perts x dimensionality
+        target_col: str = "perturbation",
+    ) -> tuple[AnnData, AnnData] | AnnData:
+        """Add perturbations linearly. Assumes input of size n_perts x dimensionality.

         Args:
             adata: Anndata object of size n_perts x dim.
             perturbations: Perturbations to add.
             reference_key: perturbation source from which the perturbation summation starts.
             ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-            target_col: .obs column name that stores the label of the perturbation applied to each cell.
+            target_col: .obs column name that stores the label of the perturbation applied to each cell.
+
+        Returns:
+            Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
+            If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
+            provided as input but updated using compute_control_diff.

         Examples:
            Example usage with PseudobulkSpace:
+
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
            >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
-           >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key=
+           >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key="NT")
         """
         new_pert_name = ""
         for perturbation in perturbations:
@@ -150,8 +182,8 @@ class PerturbationSpace:
             new_pert_name += perturbation + "+"

         if not ensure_consistency:
-
-            "
+            logger.warning(
+                "Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
                 "Subtract control perturbation needed for consistency of space in all data representations. \n"
                 "Run with ensure_consistency=True"
             )
@@ -212,6 +244,8 @@ class PerturbationSpace:
             key_name = key.removesuffix("_control_diff")
             new_perturbation.obsm[key_name] = data["embeddings"][key]

+        new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
+
         if ensure_consistency:
             return new_perturbation, adata

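`add` is now annotated to return either a single AnnData or, when `ensure_consistency=True`, a `(new_perturbation, adata)` tuple in which the input is re-expressed as control differences. A short sketch following the doctest in this hunk:

```python
import pertpy as pt

mdata = pt.dt.papalexi_2021()
ps = pt.tl.PseudobulkSpace()
ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")

# ensure_consistency=True returns the summed perturbation together with the
# input AnnData updated via compute_control_diff, keeping the linear space consistent.
new_perturbation, ps_adata = ps.add(
    ps_adata,
    perturbations=["ATF2", "CD86"],
    reference_key="NT",
    ensure_consistency=True,
)
```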
@@ -223,24 +257,30 @@ class PerturbationSpace:
         perturbations: Iterable[str],
         reference_key: str = "control",
         ensure_consistency: bool = False,
-        target_col: str = "
-    ):
+        target_col: str = "perturbation",
+    ) -> tuple[AnnData, AnnData] | AnnData:
         """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality

         Args:
             adata: Anndata object of size n_perts x dim.
-            perturbations: Perturbations to subtract
-            reference_key: Perturbation source from which the perturbation subtraction starts
+            perturbations: Perturbations to subtract.
+            reference_key: Perturbation source from which the perturbation subtraction starts.
             ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-            target_col: .obs column name that stores the label of the perturbation applied to each cell.
+            target_col: .obs column name that stores the label of the perturbation applied to each cell.
+
+        Returns:
+            Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
+            If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
+            provided as input but updated using compute_control_diff.

         Examples:
            Example usage with PseudobulkSpace:
+
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
            >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
-           >>> new_perturbation = ps.
+           >>> new_perturbation = ps.subtract(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
         """
         new_pert_name = reference_key + "-"
         for perturbation in perturbations:
@@ -251,8 +291,8 @@ class PerturbationSpace:
             new_pert_name += perturbation + "-"

         if not ensure_consistency:
-
-            "
+            logger.warning(
+                "Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
                 "Subtract control perturbation needed for consistency of space in all data representations.\n"
                 "Run with ensure_consistency=True"
             )
@@ -313,7 +353,61 @@ class PerturbationSpace:
             key_name = key.removesuffix("_control_diff")
             new_perturbation.obsm[key_name] = data["embeddings"][key]

+        new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
+
         if ensure_consistency:
             return new_perturbation, adata

         return new_perturbation
+
+    def label_transfer(
+        self,
+        adata: AnnData,
+        column: str = "perturbation",
+        target_val: str = "unknown",
+        n_neighbors: int = 5,
+        use_rep: str = "X_umap",
+    ) -> None:
+        """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
+
+        Args:
+            adata: The AnnData object containing single-cell data.
+            column: The column name in AnnData object to perform imputation on.
+            target_val: The target value to impute.
+            n_neighbors: Number of neighbors to use for imputation.
+            use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> import scanpy as sc
+            >>> import numpy as np
+            >>> adata = sc.datasets.pbmc68k_reduced()
+            >>> rng = np.random.default_rng()
+            >>> adata.obs["perturbation"] = rng.choice(
+            ...     ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
+            ... )
+            >>> sc.pp.neighbors(adata)
+            >>> sc.tl.umap(adata)
+            >>> ps = pt.tl.PseudobulkSpace()
+            >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
+        """
+        if use_rep not in adata.obsm:
+            raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
+
+        embedding = adata.obsm[use_rep]
+
+        from pynndescent import NNDescent
+
+        nnd = NNDescent(embedding, n_neighbors=n_neighbors)
+        indices, _ = nnd.query(embedding, k=n_neighbors)
+
+        perturbations = np.array(adata.obs[column])
+        missing_mask = perturbations == target_val
+
+        for idx in np.where(missing_mask)[0]:
+            neighbor_indices = indices[idx]
+            neighbor_categories = perturbations[neighbor_indices]
+            most_common = pd.Series(neighbor_categories).mode()[0]
+            perturbations[idx] = most_common
+
+        adata.obs[column] = perturbations
pertpy/tools/_perturbation_space/_simple.py
CHANGED
@@ -15,9 +15,10 @@ class CentroidSpace(PerturbationSpace):
     def compute(
         self,
         adata: AnnData,
-        target_col: str = "
+        target_col: str = "perturbation",
         layer_key: str = None,
         embedding_key: str = "X_umap",
+        keep_obs: bool = True,
     ) -> AnnData:  # type: ignore
         """Computes the centroids of a pre-computed embedding such as UMAP.

@@ -26,6 +27,12 @@ class CentroidSpace(PerturbationSpace):
             target_col: .obs column that stores the label of the perturbation applied to each cell.
             layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
             embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
+            keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
+                each cell of one perturbation are kept.
+
+        Returns:
+            AnnData object with one observation per perturbation, storing the embedding data of the
+            centroid of the respective perturbation.

         Examples:
             Compute the centroids of a UMAP embedding of the papalexi_2021 dataset:
@@ -34,7 +41,7 @@ class CentroidSpace(PerturbationSpace):
            >>> import scanpy as sc
            >>> mdata = pt.dt.papalexi_2021()
            >>> sc.pp.pca(mdata["rna"])
-           >>> sc.pp.neighbors(mdata[
+           >>> sc.pp.neighbors(mdata["rna"])
            >>> sc.tl.umap(mdata["rna"])
            >>> cs = pt.tl.CentroidSpace()
            >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target")
@@ -84,6 +91,22 @@ class CentroidSpace(PerturbationSpace):

         ps_adata = AnnData(X=X)
         ps_adata.obs_names = index
+        ps_adata.obs[target_col] = index
+
+        if embedding_key is not None:
+            ps_adata.obsm[embedding_key] = X
+
+        if keep_obs:  # Save the values of the obs columns of interest in the ps_adata object
+            obs_df = adata.obs
+            obs_df = obs_df.groupby(target_col).agg(
+                lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
+            )
+            for obs_name in obs_df.columns:
+                if not obs_df[obs_name].isnull().values.any():
+                    mapping = {pert: obs_df.loc[pert][obs_name] for pert in index}
+                    ps_adata.obs[obs_name] = ps_adata.obs[target_col].map(mapping)
+
+        ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")

         return ps_adata

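CentroidSpace now also writes the centroid coordinates into `.obsm[embedding_key]` and, with the new `keep_obs=True` default, carries over `.obs` columns that are constant within each perturbation. A sketch adapted from the docstring example above (the column carry-over behavior is exactly what the hunk implements):

```python
import pertpy as pt
import scanpy as sc

mdata = pt.dt.papalexi_2021()
sc.pp.pca(mdata["rna"])
sc.pp.neighbors(mdata["rna"])
sc.tl.umap(mdata["rna"])

cs = pt.tl.CentroidSpace()
# One observation per perturbation; .obs columns with a single value per
# perturbation survive the aggregation because keep_obs defaults to True.
cs_adata = cs.compute(mdata["rna"], target_col="gene_target", embedding_key="X_umap")
```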
@@ -94,7 +117,8 @@ class PseudobulkSpace(PerturbationSpace):
     def compute(
         self,
         adata: AnnData,
-        target_col: str = "
+        target_col: str = "perturbation",
+        groups_col: str = None,
         layer_key: str = None,
         embedding_key: str = None,
         **kwargs,
@@ -104,19 +128,21 @@ class PseudobulkSpace(PerturbationSpace):
         Args:
             adata: Anndata object of size cells x genes
             target_col: .obs column that stores the label of the perturbation applied to each cell.
+            groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
+                The summarized expression per perturbation (target_col) and group (groups_col) is computed.
             layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
             embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
             **kwargs: Are passed to decoupler's get_pseuobulk.

+        Returns:
+            AnnData object with one observation per perturbation.
+
         Examples:
-           >>> import pertpy as
+           >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
-           >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target"
+           >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target")
         """
-        if "groups_col" not in kwargs:
-            kwargs["groups_col"] = "perturbations"
-
         if layer_key is not None and embedding_key is not None:
             raise ValueError("Please, select just either layer or embedding for computation.")

@@ -135,7 +161,10 @@ class PseudobulkSpace(PerturbationSpace):
             adata_emb.obs = adata.obs
             adata = adata_emb

-
+        adata.obs[target_col] = adata.obs[target_col].astype("category")
+        ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs)  # type: ignore
+
+        ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")

         return ps_adata

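PseudobulkSpace no longer injects a default `groups_col` into `kwargs`; grouping is now an explicit argument forwarded to decoupler's `get_pseudobulk`. A sketch of the new call pattern; the `"replicate"` grouping column is illustrative, and the keyword arguments shown are standard decoupler `get_pseudobulk` options passed through `**kwargs`:

```python
import pertpy as pt

mdata = pt.dt.papalexi_2021()
ps = pt.tl.PseudobulkSpace()

# One pseudobulk profile per perturbation and (assumed) replicate group;
# mode/min_cells/min_counts are forwarded to decoupler's get_pseudobulk.
ps_adata = ps.compute(
    mdata["rna"],
    target_col="gene_target",
    groups_col="replicate",
    mode="sum",
    min_cells=0,
    min_counts=0,
)
```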
@@ -164,6 +193,11 @@ class KMeansSpace(ClusteringSpace):
             return_object: if True returns the clustering object
             **kwargs: Are passed to sklearn's KMeans.

+        Returns:
+            If return_object is True, the adata and the clustering object is returned.
+            Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
+            that stores the cluster labels.
+
         Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
@@ -193,6 +227,7 @@ class KMeansSpace(ClusteringSpace):

         clustering = KMeans(**kwargs).fit(self.X)
         adata.obs[cluster_key] = clustering.labels_
+        adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

         if return_object:
             return adata, clustering
@@ -212,18 +247,23 @@ class DBSCANSpace(ClusteringSpace):
         copy: bool = True,
         return_object: bool = False,
         **kwargs,
-    ) -> tuple[AnnData, object | AnnData
+    ) -> tuple[AnnData, object] | AnnData:
         """Computes a clustering using Density-based spatial clustering of applications (DBSCAN).

         Args:
             adata: Anndata object of size cells x genes
             layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
             embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
-            cluster_key: name of the .obs column to store the cluster labels.
+            cluster_key: name of the .obs column to store the cluster labels.
             copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
             return_object: if True returns the clustering object
             **kwargs: Are passed to sklearn's DBSCAN.

+        Returns:
+            If return_object is True, the adata and the clustering object is returned.
+            Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
+            that stores the cluster labels.
+
         Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
@@ -250,6 +290,7 @@ class DBSCANSpace(ClusteringSpace):

         clustering = DBSCAN(**kwargs).fit(self.X)
         adata.obs[cluster_key] = clustering.labels_
+        adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

         if return_object:
             return adata, clustering
pertpy/tools/_scgen/__init__.py
CHANGED
@@ -1 +1 @@
-from pertpy.tools._scgen.
+from pertpy.tools._scgen._scgen import Scgen
pertpy/tools/_scgen/_base_components.py
CHANGED
@@ -28,7 +28,7 @@ class FlaxEncoder(nn.Module):

         Args:
             x: The input data matrix.
-            training: Whether
+            training: Whether to use running training average.

         Returns:
             Mean and variance.
@@ -69,12 +69,11 @@ class FlaxDecoder(nn.Module):

         Args:
             x: Input data.
-            training:
+            training: Whether to use running training average.

         Returns:
             Decoded data.
         """
-
         training = nn.merge_param("training", self.training, training)

         for _ in range(self.n_layers):