pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +1 -3
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +133 -25
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +138 -44
- pertpy/preprocessing/_guide_rna_mixture.py +17 -19
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +106 -98
- pertpy/tools/_cinemaot.py +74 -114
- pertpy/tools/_coda/_base_coda.py +129 -145
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +48 -40
- pertpy/tools/_differential_gene_expression/_base.py +21 -31
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +31 -45
- pertpy/tools/_enrichment.py +7 -22
- pertpy/tools/_milo.py +19 -15
- pertpy/tools/_mixscape.py +73 -75
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +12 -14
- pertpy/tools/_scgen/_scgen.py +16 -17
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.10.0.dist-info/RECORD +0 -58
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_perturbation_space/_discriminator_classifiers.py
CHANGED
@@ -5,10 +5,10 @@ import warnings
 import anndata
 import numpy as np
 import pandas as pd
-import pytorch_lightning as pl
 import scipy
 import torch
 from anndata import AnnData
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import train_test_split
@@ -35,9 +35,7 @@ class LRClassifierSpace(PerturbationSpace):
         test_split_size: float = 0.2,
         max_iter: int = 1000,
     ):
-        """
-        Fits a logistic regression model to the data and takes the coefficients of the logistic regression
-        model as perturbation embedding.
+        """Fits a logistic regression model to the data and takes the coefficients of the logistic regression model as perturbation embedding.
 
         Args:
             adata: AnnData object of size cells x genes
@@ -60,7 +58,7 @@ class LRClassifierSpace(PerturbationSpace):
         if layer_key is not None and layer_key not in adata.obs.columns:
             raise ValueError(f"Layer key {layer_key} not found in adata.")
 
-        if embedding_key is not None and embedding_key not in adata.obsm.keys():
+        if embedding_key is not None and embedding_key not in adata.obsm:
             raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")
 
         if layer_key is not None and embedding_key is not None:
@@ -207,7 +205,7 @@ class MLPClassifierSpace(PerturbationSpace):
         adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
 
         # Split the data in train, test and validation
-        X = list(range(0, adata.n_obs))
+        X = list(range(adata.n_obs))
         y = adata.obs[target_col]
 
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
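
The change above looks like ruff's PIE808 fix (drop the redundant `0` start argument of `range`), assuming that is what the truncated removed line contained. For context, a minimal, self-contained sketch of the stratified split pattern used here, with hypothetical labels:

    import numpy as np
    from sklearn.model_selection import train_test_split

    X = list(range(100))                 # cell indices stand in for the data
    y = np.array(["ctrl", "pert"] * 50)  # hypothetical perturbation labels

    # stratify=y keeps the ctrl/pert proportions equal in both splits
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
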
@@ -248,7 +246,7 @@ class MLPClassifierSpace(PerturbationSpace):
         # Save adata observations for embedding annotations in get_embeddings
         self.adata_obs = adata.obs.reset_index(drop=True)
 
-        self.trainer = pl.Trainer(
+        self.trainer = Trainer(
            min_epochs=1,
            max_epochs=max_epochs,
            check_val_every_n_epoch=val_epochs_check,
@@ -273,7 +271,7 @@ class MLPClassifierSpace(PerturbationSpace):
             if dataset_count == 0:
                 pert_adata = batch_adata
             else:
-                pert_adata = anndata.concat([pert_adata, batch_adata])
+                pert_adata = batch_adata if dataset_count == 0 else anndata.concat([pert_adata, batch_adata])
 
         # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
         # and we can just add the annotations from the original AnnData object
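
The rewritten line folds the first-batch seeding into a conditional expression. A toy, self-contained version of the same accumulation pattern (the batches here are synthetic, not pertpy's dataloader output):

    import anndata
    import numpy as np

    batches = [anndata.AnnData(X=np.ones((2, 3))) for _ in range(3)]

    pert_adata = None
    for dataset_count, batch_adata in enumerate(batches):
        # the first batch seeds the result; later batches are concatenated onto it
        pert_adata = batch_adata if dataset_count == 0 else anndata.concat([pert_adata, batch_adata])
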
@@ -308,9 +306,7 @@ class MLPClassifierSpace(PerturbationSpace):
 
 
 class MLP(torch.nn.Module):
-    """
-    A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
-    """
+    """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
 
     def __init__(
         self,
@@ -320,7 +316,8 @@ class MLP(torch.nn.Module):
         layer_norm: bool = False,
         last_layer_act: str = "linear",
     ) -> None:
-        """
+        """Multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
+
         Args:
             sizes: size of layers.
             dropout: Dropout probability.
@@ -375,8 +372,8 @@ def init_weights(m):
 
 
 class PLDataset(Dataset):
-    """
-
+    """Dataset for perturbation classification.
+
     Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
     """
 
@@ -387,14 +384,14 @@ class PLDataset(Dataset):
         label_col: str = "perturbations",
         layer_key: str = None,
     ):
-        """
+        """PyTorch lightning Dataset for perturbation classification.
+
         Args:
             adata: AnnData object with observations and labels.
             target_col: key with the perturbation labels numerically encoded.
             label_col: key with the perturbation labels.
-            layer_key: key of the layer to be used as data, otherwise .X
+            layer_key: key of the layer to be used as data, otherwise .X.
         """
-
         if layer_key:
             self.data = adata.layers[layer_key]
         else:
|
|
407
404
|
return self.data.shape[0]
|
408
405
|
|
409
406
|
def __getitem__(self, idx):
|
410
|
-
"""Returns a sample and corresponding perturbations applied (labels)"""
|
407
|
+
"""Returns a sample and corresponding perturbations applied (labels)."""
|
411
408
|
sample = self.data[idx].toarray().squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
|
412
409
|
num_label = self.labels.iloc[idx]
|
413
410
|
str_label = self.pert_labels.iloc[idx]
|
@@ -415,7 +412,7 @@ class PLDataset(Dataset):
         return sample, num_label, str_label
 
 
-class PerturbationClassifier(pl.LightningModule):
+class PerturbationClassifier(LightningModule):
     def __init__(
         self,
         model: torch.nn.Module,
@@ -428,7 +425,8 @@ class PerturbationClassifier(pl.LightningModule):
         lr=1e-4,
         seed=42,
     ):
-        """
+        """Perturbation Classifier.
+
         Args:
             model: model to be trained
             batch_size: batch size
@@ -438,7 +436,7 @@ class PerturbationClassifier(pl.LightningModule):
             layer_norm: whether to apply layer norm
             last_layer_act: activation function of last layer
             lr: learning rate
-            seed: random seed
+            seed: random seed.
         """
         super().__init__()
         self.batch_size = batch_size
@@ -457,16 +455,37 @@ class PerturbationClassifier(pl.LightningModule):
             last_layer_act=self.hparams.last_layer_act,
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the network.
+
+        Args:
+            x: Input tensor
+
+        Returns:
+            Network output tensor
+        """
         x = self.net(x)
         return x
 
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
+    def configure_optimizers(self) -> optim.Adam:
+        """Configure optimizer for the model.
 
+        Returns:
+            Adam optimizer with weight decay
+        """
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
         return optimizer
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """Perform a training step.
+
+        Args:
+            batch: Tuple of (input, target, metadata)
+            batch_idx: Index of the current batch
+
+        Returns:
+            Loss value
+        """
         x, y, _ = batch
         x = x.to(torch.float32)
 
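
The `configure_optimizers` hook above is standard PyTorch Lightning; outside of pertpy, the same Adam-with-weight-decay setup looks like this (the linear model is a stand-in for the MLP, hyperparameters as in the hunk):

    from torch import nn, optim

    model = nn.Linear(10, 2)  # stand-in for the MLP behind PerturbationClassifier

    # Adam with L2 regularization via weight_decay, as in configure_optimizers above
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.1)
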
@@ -480,7 +499,16 @@ class PerturbationClassifier(pl.LightningModule):
 
         return loss
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """Perform a validation step.
+
+        Args:
+            batch: Tuple of (input, target, metadata)
+            batch_idx: Index of the current batch
+
+        Returns:
+            Loss value
+        """
         x, y, _ = batch
         x = x.to(torch.float32)
 
@@ -494,7 +522,16 @@ class PerturbationClassifier(pl.LightningModule):
 
         return loss
 
-    def test_step(self, batch, batch_idx):
+    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """Perform a test step.
+
+        Args:
+            batch: Tuple of (input, target, metadata)
+            batch_idx: Index of the current batch
+
+        Returns:
+            Loss value
+        """
         x, y, _ = batch
         x = x.to(torch.float32)
 
@@ -508,15 +545,29 @@ class PerturbationClassifier(pl.LightningModule):
 
         return loss
 
-    def embedding(self, x):
-        """
-        Inputs:
-            x - Input features of shape [Batch, SeqLen, 1]
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        """Extract embeddings from input features.
+
+        Args:
+            x: Input tensor of shape [Batch, SeqLen, 1]
+
+        Returns:
+            Embedded representation of the input
         """
         x = self.net.embedding(x)
         return x
 
-    def get_embeddings(self, batch):
+    def get_embeddings(
+        self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Extract embeddings from a batch.
+
+        Args:
+            batch: Tuple of (input, target, metadata)
+
+        Returns:
+            Tuple of (embeddings, metadata)
+        """
         x, _, y = batch
         x = x.to(torch.float32)
 
pertpy/tools/_perturbation_space/_perturbation_space.py
CHANGED
@@ -70,7 +70,7 @@ class PerturbationSpace:
         if embedding_key is not None and embedding_key not in adata.obsm_keys():
             raise ValueError(f"Embedding key {embedding_key} not found in obsm keys of the anndata.")
 
-        if layer_key is not None and layer_key not in adata.layers.keys():
+        if layer_key is not None and layer_key not in adata.layers:
             raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
 
         if copy:
@@ -123,7 +123,7 @@ class PerturbationSpace:
         if all_data:
             layers_keys = list(adata.layers.keys())
             for local_layer_key in layers_keys:
-                if local_layer_key != layer_key and local_layer_key != new_layer_key:
+                if local_layer_key not in (layer_key, new_layer_key):
                     adata.layers[local_layer_key + "_control_diff"] = np.zeros((adata.n_obs, adata.n_vars))
                     for mask in group_masks:
                         adata.layers[local_layer_key + "_control_diff"][mask, :] = adata.layers[local_layer_key][
@@ -132,7 +132,7 @@ class PerturbationSpace:
 
             embedding_keys = list(adata.obsm_keys())
             for local_embedding_key in embedding_keys:
-                if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
+                if local_embedding_key not in (embedding_key, new_embedding_key):
                     adata.obsm[local_embedding_key + "_control_diff"] = np.zeros(adata.obsm[local_embedding_key].shape)
                     for mask in group_masks:
                         adata.obsm[local_embedding_key + "_control_diff"][mask, :] = adata.obsm[local_embedding_key][
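
Both hunks swap a chained comparison for a tuple-membership test (the removed lines are truncated in this rendering; the chained `!=` form shown is the usual pre-fix pattern). The two are equivalent:

    layer_key, new_layer_key = "counts", "control_diff"

    for local_layer_key in ("counts", "control_diff", "lognorm"):
        old_style = local_layer_key != layer_key and local_layer_key != new_layer_key
        new_style = local_layer_key not in (layer_key, new_layer_key)
        assert old_style == new_style  # only "lognorm" passes either test
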
@@ -193,7 +193,7 @@ class PerturbationSpace:
 
         data: dict[str, np.array] = {}
 
-        for local_layer_key in adata.layers.keys():
+        for local_layer_key in adata.layers:
             data["layers"] = {}
             control_local = adata[reference_key].layers[local_layer_key].copy()
             for perturbation in perturbations:
@@ -231,14 +231,14 @@ class PerturbationSpace:
         new_obs.loc[new_pert_name[:-1]] = new_pert_obs
         new_perturbation.obs = new_obs
 
-        if "layers" in data.keys():
+        if "layers" in data:
             for key in data["layers"]:
                 key_name = key
                 if key.endswith("_control_diff"):
                     key_name = key.removesuffix("_control_diff")
                 new_perturbation.layers[key_name] = data["layers"][key]
 
-        if "embeddings" in data.keys():
+        if "embeddings" in data:
             key_name = key
             for key in data["embeddings"]:
                 if key.endswith("_control_diff"):
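
The added lines drop what was presumably a redundant `.keys()` call (ruff's SIM118 fix): `key in d` already tests key membership on a dict. A one-liner to confirm:

    data = {"layers": {}, "embeddings": {}}
    assert ("layers" in data) == ("layers" in data.keys())  # identical membership test
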
@@ -260,7 +260,7 @@ class PerturbationSpace:
         ensure_consistency: bool = False,
         target_col: str = "perturbation",
     ) -> tuple[AnnData, AnnData] | AnnData:
-        """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
+        """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality.
 
         Args:
             adata: Anndata object of size n_perts x dim.
@@ -302,7 +302,7 @@ class PerturbationSpace:
 
         data: dict[str, np.array] = {}
 
-        for local_layer_key in adata.layers.keys():
+        for local_layer_key in adata.layers:
             data["layers"] = {}
             control_local = adata[reference_key].layers[local_layer_key].copy()
             for perturbation in perturbations:
@@ -340,14 +340,14 @@ class PerturbationSpace:
         new_obs.loc[new_pert_name[:-1]] = new_pert_obs
         new_perturbation.obs = new_obs
 
-        if "layers" in data.keys():
+        if "layers" in data:
             for key in data["layers"]:
                 key_name = key
                 if key.endswith("_control_diff"):
                     key_name = key.removesuffix("_control_diff")
                 new_perturbation.layers[key_name] = data["layers"][key]
 
-        if "embeddings" in data.keys():
+        if "embeddings" in data:
             key_name = key
             for key in data["embeddings"]:
                 if key.endswith("_control_diff"):
pertpy/tools/_perturbation_space/_simple.py
CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-import decoupler as dc
 import matplotlib.pyplot as plt
 import numpy as np
 from anndata import AnnData
+from decoupler import get_pseudobulk as dc_get_pseudobulk
+from decoupler import plot_psbulk_samples as dc_plot_psbulk_samples
 from sklearn.cluster import DBSCAN, KMeans
 
 from pertpy._doc import _doc_params, doc_common_plot_args
@@ -53,7 +54,6 @@ class CentroidSpace(PerturbationSpace):
         >>> cs = pt.tl.CentroidSpace()
         >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target")
         """
-
         X = None
         if layer_key is not None and embedding_key is not None:
             raise ValueError("Please, select just either layer or embedding for computation.")
@@ -65,7 +65,7 @@ class CentroidSpace(PerturbationSpace):
             X = np.empty((len(adata.obs[target_col].unique()), adata.obsm[embedding_key].shape[1]))
 
         if layer_key is not None:
-            if layer_key not in adata.layers.keys():
+            if layer_key not in adata.layers:
                 raise ValueError(f"Layer {layer_key!r} does not exist in the .layers attribute.")
             else:
                 X = np.empty((len(adata.obs[target_col].unique()), adata.layers[layer_key].shape[1]))
@@ -79,8 +79,7 @@ class CentroidSpace(PerturbationSpace):
             X = np.empty((len(adata.obs[target_col].unique()), adata.obsm[embedding_key].shape[1]))
 
         index = []
-        pert_index = 0
-        for group_name, group_data in grouped:
+        for pert_index, (group_name, group_data) in enumerate(grouped):
             indices = group_data.index
             if layer_key is not None:
                 points = adata[indices].layers[layer_key]
@@ -94,7 +93,6 @@ class CentroidSpace(PerturbationSpace):
                 points, key=lambda point: np.linalg.norm(point - centroid)
             )  # Find the point in the array closest to the centroid
             X[pert_index, :] = closest_point
-            pert_index += 1
 
         ps_adata = AnnData(X=X)
         ps_adata.obs_names = index
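
Taken together, these two hunks replace a manually maintained counter with `enumerate` over the groupby. Reduced to a toy example (the grouping below is hypothetical):

    import pandas as pd

    grouped = pd.DataFrame({"target": ["a", "a", "b"]}).groupby("target")

    # enumerate supplies pert_index; no separate counter or += 1 is needed
    for pert_index, (group_name, group_data) in enumerate(grouped):
        print(pert_index, group_name, len(group_data))
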
@@ -153,7 +151,7 @@ class PseudobulkSpace(PerturbationSpace):
         if layer_key is not None and embedding_key is not None:
             raise ValueError("Please, select just either layer or embedding for computation.")
 
-        if layer_key is not None and layer_key not in adata.layers.keys():
+        if layer_key is not None and layer_key not in adata.layers:
             raise ValueError(f"Layer {layer_key!r} does not exist in the .layers attribute.")
 
         if target_col not in adata.obs:
@@ -169,14 +167,14 @@ class PseudobulkSpace(PerturbationSpace):
             adata = adata_emb
 
         adata.obs[target_col] = adata.obs[target_col].astype("category")
-        ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs)  # type: ignore
+        ps_adata = dc_get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs)  # type: ignore
 
         ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
 
         return ps_adata
 
     @_doc_params(common_plot_args=doc_common_plot_args)
-    def plot_psbulk_samples(
+    def plot_psbulk_samples(  # pragma: no cover # noqa: D417
         self,
         adata: AnnData,
         groupby: str,
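
The module now imports the two decoupler functions it uses directly instead of keeping the whole `dc` namespace; behavior is unchanged. A sketch of how `PseudobulkSpace.compute` is typically called, adapted from pertpy's docstrings (dataset and column names are illustrative):

    import pertpy as pt

    mdata = pt.dt.papalexi_2021()  # example dataset; any AnnData with a perturbation column works
    ps = pt.tl.PseudobulkSpace()
    # internally dispatches to decoupler's get_pseudobulk with sample_col=target_col
    ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="guide_ID")
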
@@ -209,7 +207,7 @@ class PseudobulkSpace(PerturbationSpace):
         Preview:
             .. image:: /_static/docstring_previews/pseudobulk_samples.png
         """
-        fig = dc.plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
+        fig = dc_plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
 
         if return_fig:
             return fig
@@ -244,7 +242,7 @@ class KMeansSpace(ClusteringSpace):
         Returns:
             If return_object is True, the adata and the clustering object is returned.
             Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
-
+            that stores the cluster labels.
 
         Examples:
             >>> import pertpy as pt
@@ -265,7 +263,7 @@ class KMeansSpace(ClusteringSpace):
             self.X = adata.obsm[embedding_key]
 
         elif layer_key is not None:
-            if layer_key not in adata.layers.keys():
+            if layer_key not in adata.layers:
                 raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
             else:
                 self.X = adata.layers[layer_key]
@@ -284,7 +282,7 @@ class KMeansSpace(ClusteringSpace):
 
 
 class DBSCANSpace(ClusteringSpace):
-    """Cluster the given data using DBSCAN"""
+    """Cluster the given data using DBSCAN."""
 
     def compute(  # type: ignore
         self,
@@ -328,7 +326,7 @@ class DBSCANSpace(ClusteringSpace):
             self.X = adata.obsm[embedding_key]
 
         elif layer_key is not None:
-            if layer_key not in adata.layers.keys():
+            if layer_key not in adata.layers:
                 raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
             else:
                 self.X = adata.layers[layer_key]
pertpy/tools/_scgen/_scgen.py
CHANGED
@@ -77,7 +77,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
             restrict_arithmetic_to: Dictionary of celltypes you want to be observed for prediction.
 
         Returns:
-            `np.ndarray` of predicted cells in primary space.
+            :class:`numpy.ndarray` of predicted cells in primary space.
             delta: float
                 Difference between stimulated and control cells in latent space
 
@@ -198,7 +198,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
             corresponding to batch and cell type metadata, respectively.
 
         Returns:
-            corrected: `~anndata.AnnData`
+            A corrected `~anndata.AnnData` object.
             AnnData of corrected gene expression in adata.X and corrected latent space in adata.obsm["latent"].
             A reference to the original AnnData is in `corrected.raw` if the input adata had no `raw` attribute.
 
@@ -343,6 +343,8 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
                 AnnData object used to initialize the model.
             indices: Indices of cells in adata to use. If `None`, all cells are used.
             batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
+            give_mean: Whether to return the mean
+            n_samples: The number of samples to use.
 
         Returns:
             Low-dimensional representation for each cell
@@ -365,17 +367,14 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
         latent = []
         for array_dict in scdl:
             out = jit_inference_fn(self.module.rngs, array_dict)
-            if give_mean:
-                z = out["qz"].mean
-            else:
-                z = out["z"]
+            z = out["qz"].mean if give_mean else out["z"]
             latent.append(z)
         concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
         latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore
 
         return self.module.as_numpy_array(latent)
 
-    def plot_reg_mean_plot(
+    def plot_reg_mean_plot(  # pragma: no cover # noqa: D417
         self,
         adata,
         condition_key: str,
@@ -495,14 +494,14 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
         ax.text(
             max(x) - max(x) * x_coeff,
             max(y) - y_coeff * max(y),
-            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
+            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value**2:.2f}",
             fontsize=kwargs.get("textsize", fontsize),
         )
         if diff_genes is not None:
             ax.text(
                 max(x) - max(x) * x_coeff,
                 max(y) - (y_coeff + 0.15) * max(y),
-                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
+                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff**2:.2f}",
                 fontsize=kwargs.get("textsize", fontsize),
             )
 
|
|
516
515
|
else:
|
517
516
|
return r_value**2
|
518
517
|
|
519
|
-
def plot_reg_var_plot(
|
518
|
+
def plot_reg_var_plot( # pragma: no cover # noqa: D417
|
520
519
|
self,
|
521
520
|
adata,
|
522
521
|
condition_key: str,
|
@@ -576,7 +575,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
576
575
|
m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(x_diff, y_diff)
|
577
576
|
if verbose:
|
578
577
|
logger.info("Top 100 DEGs var: ", r_value_diff**2)
|
579
|
-
if "y1" in axis_keys
|
578
|
+
if "y1" in axis_keys:
|
580
579
|
real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]
|
581
580
|
x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
|
582
581
|
y = np.asarray(np.var(stim.X, axis=0)).ravel()
|
@@ -594,7 +593,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
594
593
|
# plt.plot(x, m * x + b, "-", color="green")
|
595
594
|
ax.set_xlabel(labels["x"], fontsize=fontsize)
|
596
595
|
ax.set_ylabel(labels["y"], fontsize=fontsize)
|
597
|
-
if "y1" in axis_keys
|
596
|
+
if "y1" in axis_keys:
|
598
597
|
y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
|
599
598
|
_ = plt.scatter(
|
600
599
|
x,
|
@@ -611,7 +610,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
611
610
|
y_bar = y[j]
|
612
611
|
plt.text(x_bar, y_bar, i, fontsize=11, color="black")
|
613
612
|
plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
|
614
|
-
if "y1" in axis_keys
|
613
|
+
if "y1" in axis_keys:
|
615
614
|
y1_bar = y1[j]
|
616
615
|
plt.text(x_bar, y1_bar, "*", color="black", alpha=0.5)
|
617
616
|
if legend:
|
@@ -623,14 +622,14 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
623
622
|
ax.text(
|
624
623
|
max(x) - max(x) * x_coeff,
|
625
624
|
max(y) - y_coeff * max(y),
|
626
|
-
r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value
|
625
|
+
r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value**2:.2f}",
|
627
626
|
fontsize=kwargs.get("textsize", fontsize),
|
628
627
|
)
|
629
628
|
if diff_genes is not None:
|
630
629
|
ax.text(
|
631
630
|
max(x) - max(x) * x_coeff,
|
632
631
|
max(y) - (y_coeff + 0.15) * max(y),
|
633
|
-
r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff
|
632
|
+
r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff**2:.2f}",
|
634
633
|
fontsize=kwargs.get("textsize", fontsize),
|
635
634
|
)
|
636
635
|
|
@@ -645,7 +644,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
645
644
|
return r_value**2
|
646
645
|
|
647
646
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
648
|
-
def plot_binary_classifier(
|
647
|
+
def plot_binary_classifier( # pragma: no cover # noqa: D417
|
649
648
|
self,
|
650
649
|
scgen: Scgen,
|
651
650
|
adata: AnnData | None,
|
@@ -665,7 +664,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
665
664
|
Args:
|
666
665
|
scgen: ScGen object that was trained.
|
667
666
|
adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
|
668
|
-
AnnData object used to initialize the model. Must have been
|
667
|
+
AnnData object used to initialize the model. Must have been set up with `batch_key` and `labels_key`,
|
669
668
|
corresponding to batch and cell type metadata, respectively.
|
670
669
|
delta: Difference between stimulated and control cells in latent space
|
671
670
|
ctrl_key: Key for `control` part of the `data` found in `condition_key`.
|
pertpy/tools/_scgen/_scgenvae.py
CHANGED
@@ -24,8 +24,8 @@ class JaxSCGENVAE(JaxBaseModuleClass):
     training: bool = True
 
     def setup(self):
-        use_batch_norm_encoder = self.use_batch_norm == "encoder" or self.use_batch_norm == "both"
-        use_layer_norm_encoder = self.use_layer_norm == "encoder" or self.use_layer_norm == "both"
+        use_batch_norm_encoder = self.use_batch_norm in ("encoder", "both")
+        use_layer_norm_encoder = self.use_layer_norm in ("encoder", "both")
 
         self.encoder = FlaxEncoder(
             n_latent=self.n_latent,
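
Assuming the truncated removed lines were the usual chained `== "encoder" or == "both"` comparison, the new membership test resolves the mode flag in one step:

    use_batch_norm = "both"  # typically one of "encoder", "decoder", "both", "none"
    use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
    assert use_batch_norm_encoder
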
pertpy/tools/_scgen/_utils.py
CHANGED
@@ -32,7 +32,9 @@ def extractor(
 
         train_data = anndata.read("./data/train.h5ad")
         test_data = anndata.read("./data/test.h5ad")
-        train_data_extracted_list = extractor(train_data, "CD4T", "conditions", "cell_type", "control", "stimulated")
+        train_data_extracted_list = extractor(
+            train_data, "CD4T", "conditions", "cell_type", "control", "stimulated"
+        )
     """
     cell_with_both_condition = data[data.obs[cell_type_key] == cell_type]
     condition_1 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == ctrl_key)]