pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_perturbation_space.py
CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 import numpy as np
 import pandas as pd
 from anndata import AnnData
+from lamin_utils import logger
 from rich import print

 if TYPE_CHECKING:
@@ -15,7 +16,7 @@ class PerturbationSpace:
     """Implements various ways of interacting with PerturbationSpaces.

     We differentiate between a cell space and a perturbation space.
-    Visually speaking, in cell spaces single
+    Visually speaking, in cell spaces single data points in an embeddings summarize a cell,
     whereas in a perturbation space, data points summarize whole perturbations.
     """

@@ -25,7 +26,8 @@ class PerturbationSpace:
     def compute_control_diff(  # type: ignore
         self,
         adata: AnnData,
-        target_col: str = "
+        target_col: str = "perturbation",
+        group_col: str = None,
         reference_key: str = "control",
         layer_key: str = None,
         new_layer_key: str = "control_diff",
@@ -33,26 +35,31 @@ class PerturbationSpace:
         new_embedding_key: str = "control_diff",
         all_data: bool = False,
         copy: bool = False,
-    ):
+    ) -> AnnData:
        """Subtract mean of the control from the perturbation.

        Args:
            adata: Anndata object of size cells x genes.
-           target_col: .obs column name that stores the label of the perturbation applied to each cell.
-
-
-
-
-
+           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+           group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups.
+           reference_key: The key of the control values.
+           layer_key: Key of the AnnData layer to use for computation.
+           new_layer_key: the results are stored in the given layer.
+           embedding_key: `obsm` key of the AnnData embedding to use for computation.
+           new_embedding_key: Results are stored in a new embedding in `obsm` with this key.
            all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
            copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.

+       Returns:
+           Updated AnnData object.
+
        Examples:
            Example usage with PseudobulkSpace:
+
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
-           >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key=
+           >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key="NT")
        """
        if reference_key not in adata.obs[target_col].unique():
            raise ValueError(
@@ -69,48 +76,67 @@ class PerturbationSpace:
            adata = adata.copy()

        control_mask = adata.obs[target_col] == reference_key
-
+       group_masks = (
+           [(adata.obs[group_col] == sample) for sample in adata.obs[group_col].unique()]
+           if group_col
+           else [[True] * adata.n_obs]
+       )

        if layer_key:
-
-
-
-
-
-
+           adata.layers[new_layer_key] = np.zeros((adata.n_obs, adata.n_vars))
+           for mask in group_masks:
+               num_control = (control_mask & mask).sum()
+               if num_control == 1:
+                   control_expression = adata.layers[layer_key][(control_mask & mask), :]
+               elif num_control > 1:
+                   control_expression = np.mean(adata.layers[layer_key][(control_mask & mask), :], axis=0)
+               else:
+                   control_expression = np.zeros((1, adata.n_vars))
+               adata.layers[new_layer_key][mask, :] = adata.layers[layer_key][mask, :] - control_expression

        if embedding_key:
-
-
-
-
-
-
+           adata.obsm[new_embedding_key] = np.zeros(adata.obsm[embedding_key].shape)
+           for mask in group_masks:
+               num_control = (control_mask & mask).sum()
+               if num_control == 1:
+                   control_expression = adata.obsm[embedding_key][(control_mask & mask), :]
+               elif num_control > 1:
+                   control_expression = np.mean(adata.obsm[embedding_key][(control_mask & mask), :], axis=0)
+               else:
+                   control_expression = np.zeros((1, adata.n_vars))
+               adata.obsm[new_embedding_key][mask, :] = adata.obsm[embedding_key][mask, :] - control_expression

        if (not layer_key and not embedding_key) or all_data:
-
-
-
-
-
-
+           adata_x = np.zeros((adata.n_obs, adata.n_vars))
+           for mask in group_masks:
+               num_control = (control_mask & mask).sum()
+               if num_control == 1:
+                   control_expression = adata.X[(control_mask & mask), :]
+               elif num_control > 1:
+                   control_expression = np.mean(adata.X[(control_mask & mask), :], axis=0)
+               else:
+                   control_expression = np.zeros((1, adata.n_vars))
+               adata_x[mask, :] = adata.X[mask, :] - control_expression
+           adata.X = adata_x

        if all_data:
            layers_keys = list(adata.layers.keys())
            for local_layer_key in layers_keys:
                if local_layer_key != layer_key and local_layer_key != new_layer_key:
-
-
-
-
+                   adata.layers[local_layer_key + "_control_diff"] = np.zeros((adata.n_obs, adata.n_vars))
+                   for mask in group_masks:
+                       adata.layers[local_layer_key + "_control_diff"][mask, :] = adata.layers[local_layer_key][
+                           mask, :
+                       ] - np.mean(adata.layers[local_layer_key][(control_mask & mask), :], axis=0)

            embedding_keys = list(adata.obsm_keys())
            for local_embedding_key in embedding_keys:
                if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
-
-
-
-
+                   adata.obsm[local_embedding_key + "_control_diff"] = np.zeros(adata.obsm[local_embedding_key].shape)
+                   for mask in group_masks:
+                       adata.obsm[local_embedding_key + "_control_diff"][mask, :] = adata.obsm[local_embedding_key][
+                           mask, :
+                       ] - np.mean(adata.obsm[local_embedding_key][(control_mask & mask), :], axis=0)

        self.control_diff_computed = True

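The reworked body above applies the control subtraction per group when group_col is given, and falls back to a single all-True mask otherwise. A minimal usage sketch of the 0.8.0 call, following the docstring example; the "replicate" grouping column is an illustrative assumption and must exist in .obs:

    >>> import pertpy as pt
    >>> mdata = pt.dt.papalexi_2021()
    >>> ps = pt.tl.PseudobulkSpace()
    >>> diff_adata = ps.compute_control_diff(
    ...     mdata["rna"], target_col="gene_target", group_col="replicate", reference_key="NT"
    ... )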
@@ -122,24 +148,30 @@ class PerturbationSpace:
        perturbations: Iterable[str],
        reference_key: str = "control",
        ensure_consistency: bool = False,
-       target_col: str = "
-   ):
-       """Add perturbations linearly. Assumes input of size n_perts x dimensionality
+       target_col: str = "perturbation",
+   ) -> tuple[AnnData, AnnData] | AnnData:
+       """Add perturbations linearly. Assumes input of size n_perts x dimensionality.

        Args:
            adata: Anndata object of size n_perts x dim.
            perturbations: Perturbations to add.
            reference_key: perturbation source from which the perturbation summation starts.
            ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+
+       Returns:
+           Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
+           If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
+           provided as input but updated using compute_control_diff.

        Examples:
            Example usage with PseudobulkSpace:
+
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
            >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
-           >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key=
+           >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key="NT")
        """
        new_pert_name = ""
        for perturbation in perturbations:
@@ -150,8 +182,8 @@ class PerturbationSpace:
            new_pert_name += perturbation + "+"

        if not ensure_consistency:
-
-           "
+           logger.warning(
+               "Operation might be done in non-consistent space (perturbation - perturbation != control). \n"
                "Subtract control perturbation needed for consistency of space in all data representations. \n"
                "Run with ensure_consistency=True"
            )
@@ -212,6 +244,8 @@ class PerturbationSpace:
            key_name = key.removesuffix("_control_diff")
            new_perturbation.obsm[key_name] = data["embeddings"][key]

+       new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
+
        if ensure_consistency:
            return new_perturbation, adata

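Because add is now annotated as tuple[AnnData, AnnData] | AnnData, callers receive both the synthetic perturbation and the consistency-corrected input when ensure_consistency=True. A hedged sketch extending the docstring example above:

    >>> import pertpy as pt
    >>> mdata = pt.dt.papalexi_2021()
    >>> ps = pt.tl.PseudobulkSpace()
    >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
    >>> new_perturbation, ps_adata_diff = ps.add(
    ...     ps_adata, perturbations=["ATF2", "CD86"], reference_key="NT", ensure_consistency=True
    ... )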
@@ -223,24 +257,30 @@ class PerturbationSpace:
        perturbations: Iterable[str],
        reference_key: str = "control",
        ensure_consistency: bool = False,
-       target_col: str = "
-   ):
+       target_col: str = "perturbation",
+   ) -> tuple[AnnData, AnnData] | AnnData:
        """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality

        Args:
            adata: Anndata object of size n_perts x dim.
-           perturbations: Perturbations to subtract
-           reference_key: Perturbation source from which the perturbation subtraction starts
+           perturbations: Perturbations to subtract.
+           reference_key: Perturbation source from which the perturbation subtraction starts.
            ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
-           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+           target_col: .obs column name that stores the label of the perturbation applied to each cell.
+
+       Returns:
+           Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
+           If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
+           provided as input but updated using compute_control_diff.

        Examples:
            Example usage with PseudobulkSpace:
+
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
            >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
-           >>> new_perturbation = ps.
+           >>> new_perturbation = ps.subtract(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
        """
        new_pert_name = reference_key + "-"
        for perturbation in perturbations:
@@ -251,8 +291,8 @@ class PerturbationSpace:
            new_pert_name += perturbation + "-"

        if not ensure_consistency:
-
-           "
+           logger.warning(
+               "Operation might be done in non-consistent space (perturbation - perturbation != control).\n"
                "Subtract control perturbation needed for consistency of space in all data representations.\n"
                "Run with ensure_consistency=True"
            )
@@ -313,7 +353,61 @@ class PerturbationSpace:
            key_name = key.removesuffix("_control_diff")
            new_perturbation.obsm[key_name] = data["embeddings"][key]

+       new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
+
        if ensure_consistency:
            return new_perturbation, adata

        return new_perturbation
+
+   def label_transfer(
+       self,
+       adata: AnnData,
+       column: str = "perturbation",
+       target_val: str = "unknown",
+       n_neighbors: int = 5,
+       use_rep: str = "X_umap",
+   ) -> None:
+       """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
+
+       Args:
+           adata: The AnnData object containing single-cell data.
+           column: The column name in AnnData object to perform imputation on.
+           target_val: The target value to impute.
+           n_neighbors: Number of neighbors to use for imputation.
+           use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
+
+       Examples:
+           >>> import pertpy as pt
+           >>> import scanpy as sc
+           >>> import numpy as np
+           >>> adata = sc.datasets.pbmc68k_reduced()
+           >>> rng = np.random.default_rng()
+           >>> adata.obs["perturbation"] = rng.choice(
+           ...     ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
+           ... )
+           >>> sc.pp.neighbors(adata)
+           >>> sc.tl.umap(adata)
+           >>> ps = pt.tl.PseudobulkSpace()
+           >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
+       """
+       if use_rep not in adata.obsm:
+           raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
+
+       embedding = adata.obsm[use_rep]
+
+       from pynndescent import NNDescent
+
+       nnd = NNDescent(embedding, n_neighbors=n_neighbors)
+       indices, _ = nnd.query(embedding, k=n_neighbors)
+
+       perturbations = np.array(adata.obs[column])
+       missing_mask = perturbations == target_val
+
+       for idx in np.where(missing_mask)[0]:
+           neighbor_indices = indices[idx]
+           neighbor_categories = perturbations[neighbor_indices]
+           most_common = pd.Series(neighbor_categories).mode()[0]
+           perturbations[idx] = most_common
+
+       adata.obs[column] = perturbations
pertpy/tools/_perturbation_space/_simple.py
CHANGED
@@ -15,9 +15,10 @@ class CentroidSpace(PerturbationSpace):
    def compute(
        self,
        adata: AnnData,
-       target_col: str = "
+       target_col: str = "perturbation",
        layer_key: str = None,
        embedding_key: str = "X_umap",
+       keep_obs: bool = True,
    ) -> AnnData:  # type: ignore
        """Computes the centroids of a pre-computed embedding such as UMAP.

@@ -26,6 +27,12 @@ class CentroidSpace(PerturbationSpace):
            target_col: .obs column that stores the label of the perturbation applied to each cell.
            layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
            embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
+           keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
+               each cell of one perturbation are kept.
+
+       Returns:
+           AnnData object with one observation per perturbation, storing the embedding data of the
+           centroid of the respective perturbation.

        Examples:
            Compute the centroids of a UMAP embedding of the papalexi_2021 dataset:
@@ -34,7 +41,7 @@ class CentroidSpace(PerturbationSpace):
            >>> import scanpy as sc
            >>> mdata = pt.dt.papalexi_2021()
            >>> sc.pp.pca(mdata["rna"])
-           >>> sc.pp.neighbors(mdata[
+           >>> sc.pp.neighbors(mdata["rna"])
            >>> sc.tl.umap(mdata["rna"])
            >>> cs = pt.tl.CentroidSpace()
            >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target")
@@ -84,6 +91,22 @@ class CentroidSpace(PerturbationSpace):

        ps_adata = AnnData(X=X)
        ps_adata.obs_names = index
+       ps_adata.obs[target_col] = index
+
+       if embedding_key is not None:
+           ps_adata.obsm[embedding_key] = X
+
+       if keep_obs:  # Save the values of the obs columns of interest in the ps_adata object
+           obs_df = adata.obs
+           obs_df = obs_df.groupby(target_col).agg(
+               lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
+           )
+           for obs_name in obs_df.columns:
+               if not obs_df[obs_name].isnull().values.any():
+                   mapping = {pert: obs_df.loc[pert][obs_name] for pert in index}
+                   ps_adata.obs[obs_name] = ps_adata.obs[target_col].map(mapping)
+
+       ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")

        return ps_adata

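With keep_obs=True, the centroid AnnData above also carries over every .obs column whose value is constant within each perturbation; columns that vary inside a perturbation are aggregated to NaN and dropped. A sketch mirroring the docstring example (which extra columns survive depends on the dataset):

    >>> import pertpy as pt
    >>> import scanpy as sc
    >>> mdata = pt.dt.papalexi_2021()
    >>> sc.pp.pca(mdata["rna"])
    >>> sc.pp.neighbors(mdata["rna"])
    >>> sc.tl.umap(mdata["rna"])
    >>> cs = pt.tl.CentroidSpace()
    >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target", keep_obs=True)
    >>> cs_adata.obs.head()  # one row per perturbation, plus any per-perturbation-constant metadata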
@@ -94,7 +117,8 @@ class PseudobulkSpace(PerturbationSpace):
    def compute(
        self,
        adata: AnnData,
-       target_col: str = "
+       target_col: str = "perturbation",
+       groups_col: str = None,
        layer_key: str = None,
        embedding_key: str = None,
        **kwargs,
@@ -104,19 +128,21 @@ class PseudobulkSpace(PerturbationSpace):
        Args:
            adata: Anndata object of size cells x genes
            target_col: .obs column that stores the label of the perturbation applied to each cell.
+           groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
+               The summarized expression per perturbation (target_col) and group (groups_col) is computed.
            layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
            embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
            **kwargs: Are passed to decoupler's get_pseuobulk.

+       Returns:
+           AnnData object with one observation per perturbation.
+
        Examples:
-           >>> import pertpy as
+           >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> ps = pt.tl.PseudobulkSpace()
-           >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target"
+           >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target")
        """
-       if "groups_col" not in kwargs:
-           kwargs["groups_col"] = "perturbations"
-
        if layer_key is not None and embedding_key is not None:
            raise ValueError("Please, select just either layer or embedding for computation.")

@@ -135,7 +161,10 @@ class PseudobulkSpace(PerturbationSpace):
            adata_emb.obs = adata.obs
            adata = adata_emb

-
+       adata.obs[target_col] = adata.obs[target_col].astype("category")
+       ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs)  # type: ignore
+
+       ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")

        return ps_adata

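groups_col is now an explicit parameter passed straight to decoupler's get_pseudobulk, replacing the old behavior of injecting a "perturbations" default through **kwargs, and the perturbation column is cast to categorical before and after aggregation. A hedged sketch; the "replicate" grouping column is an illustrative assumption:

    >>> import pertpy as pt
    >>> mdata = pt.dt.papalexi_2021()
    >>> ps = pt.tl.PseudobulkSpace()
    >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="replicate")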
@@ -164,6 +193,11 @@ class KMeansSpace(ClusteringSpace):
            return_object: if True returns the clustering object
            **kwargs: Are passed to sklearn's KMeans.

+       Returns:
+           If return_object is True, the adata and the clustering object is returned.
+           Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
+           that stores the cluster labels.
+
        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
@@ -193,6 +227,7 @@ class KMeansSpace(ClusteringSpace):

        clustering = KMeans(**kwargs).fit(self.X)
        adata.obs[cluster_key] = clustering.labels_
+       adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

        if return_object:
            return adata, clustering
@@ -212,18 +247,23 @@ class DBSCANSpace(ClusteringSpace):
        copy: bool = True,
        return_object: bool = False,
        **kwargs,
-   ) -> tuple[AnnData, object | AnnData
+   ) -> tuple[AnnData, object] | AnnData:
        """Computes a clustering using Density-based spatial clustering of applications (DBSCAN).

        Args:
            adata: Anndata object of size cells x genes
            layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
            embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
-           cluster_key: name of the .obs column to store the cluster labels.
+           cluster_key: name of the .obs column to store the cluster labels.
            copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
            return_object: if True returns the clustering object
            **kwargs: Are passed to sklearn's DBSCAN.

+       Returns:
+           If return_object is True, the adata and the clustering object is returned.
+           Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
+           that stores the cluster labels.
+
        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
@@ -250,6 +290,7 @@ class DBSCANSpace(ClusteringSpace):

        clustering = DBSCAN(**kwargs).fit(self.X)
        adata.obs[cluster_key] = clustering.labels_
+       adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

        if return_object:
            return adata, clustering
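Both clustering spaces now store the cluster labels as a pandas categorical instead of raw integer labels, so downstream tools treat them as discrete groups. A quick check, assuming the class is exposed as pt.tl.DBSCANSpace and that min_samples is forwarded to sklearn's DBSCAN via **kwargs as documented above:

    >>> import pertpy as pt
    >>> mdata = pt.dt.papalexi_2021()
    >>> dbscan = pt.tl.DBSCANSpace()
    >>> dbscan_adata = dbscan.compute(mdata["rna"], cluster_key="dbscan", min_samples=5)
    >>> dbscan_adata.obs["dbscan"].dtype  # CategoricalDtype as of 0.8.0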
pertpy/tools/_scgen/__init__.py
CHANGED
@@ -1 +1 @@
-from pertpy.tools._scgen.
+from pertpy.tools._scgen._scgen import Scgen
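Per the file list above, the Jax-based module (_jax_scgen.py) was removed and consolidated into the new _scgen.py, with the exported class named Scgen. A minimal import check based on the rewritten __init__ shown here; the top-level pt.tl alias is an assumption about how pertpy/tools/__init__.py re-exports it:

    >>> from pertpy.tools._scgen import Scgen  # direct path per the __init__ above
    >>> import pertpy as pt
    >>> model_cls = pt.tl.Scgen  # assumed top-level alias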
pertpy/tools/_scgen/_base_components.py
CHANGED
@@ -28,7 +28,7 @@ class FlaxEncoder(nn.Module):

        Args:
            x: The input data matrix.
-           training: Whether
+           training: Whether to use running training average.

        Returns:
            Mean and variance.
@@ -69,12 +69,11 @@ class FlaxDecoder(nn.Module):

        Args:
            x: Input data.
-           training:
+           training: Whether to use running training average.

        Returns:
            Decoded data.
        """
-
        training = nn.merge_param("training", self.training, training)

        for _ in range(self.n_layers):