pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
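The 381-line hunk below removes `pertpy/tools/_perturbation_space/_discriminator_classifier.py` in its entirety; its successor module, `pertpy/tools/_perturbation_space/_discriminator_classifiers.py`, is added in this release. As a reference for the API being dropped, here is a minimal sketch of the 0.6.0 usage pattern, assembled from the removed module's own docstring examples (the `papalexi_2021` dataset and the `gene_target` column come from those docstrings):

    import pertpy as pt

    # 0.6.0 usage pattern as documented in the removed module's docstrings
    adata = pt.dt.papalexi_2021()["rna"]

    dcs = pt.tl.DiscriminatorClassifierSpace()
    dcs.load(adata, target_col="gene_target")  # builds the MLP and train/val/test dataloaders
    dcs.train(max_epochs=5)                    # fits and tests the classifier
    embeddings = dcs.get_embeddings()          # AnnData of per-cell embeddings from the trained classifier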
@@ -1,381 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import anndata
-import pytorch_lightning as pl
-import scipy
-import torch
-from anndata import AnnData
-from pytorch_lightning.callbacks import EarlyStopping
-from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import LabelEncoder
-from torch import optim
-from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
-
-from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
-
-if TYPE_CHECKING:
-    import numpy as np
-
-
-class DiscriminatorClassifierSpace(PerturbationSpace):
-    """Leveraging discriminator classifier. Fit a regressor model to the data and take the feature space.
-
-    See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19)
-    We use either the coefficients of the model for each perturbation as a feature or train a classifier example
-    (simple MLP or logistic regression and take the penultimate layer as feature space and apply pseudobulking approach).
-    """
-
-    def load(  # type: ignore
-        self,
-        adata: AnnData,
-        target_col: str = "perturbations",
-        layer_key: str = None,
-        hidden_dim: list[int] = None,
-        dropout: float = 0.0,
-        batch_norm: bool = True,
-        batch_size: int = 256,
-        test_split_size: float = 0.2,
-        validation_split_size: float = 0.25,
-    ):
-        """Creates a model with the specified parameters (hidden_dim, dropout, batch_norm).
-
-        It further creates dataloaders and fixes class imbalance due to control.
-        Sets the device to a GPU if available.
-
-        Args:
-            adata: AnnData object of size cells x genes
-            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
-            layer_key: Layer to use. Defaults to None.
-            hidden_dim: list of hidden layers of the neural network. For instance: [512, 256].
-            dropout: amount of dropout applied, constant for all layers. Defaults to 0.
-            batch_norm: Whether to apply batch normalization. Defaults to True.
-            batch_size: The batch size. Defaults to 256.
-            test_split_size: Default to 0.2.
-            validation_split_size: Size of the validation split taking into account that is taking with respect to the resultant train split.
-                Defaults to 0.25.
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.papalexi_2021()['rna']
-            >>> dcs = pt.tl.DiscriminatorClassifierSpace()
-            >>> dcs.load(adata, target_col="gene_target")
-        """
-        if layer_key is not None and layer_key not in adata.obs.columns:
-            raise ValueError(f"Layer key {layer_key} not found in adata. {layer_key}")
-
-        if target_col not in adata.obs:
-            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
-
-        if hidden_dim is None:
-            hidden_dim = [512]
-
-        # Labels are strings, one hot encoding for classification
-        n_classes = len(adata.obs[target_col].unique())
-        labels = adata.obs[target_col]
-        label_encoder = LabelEncoder()
-        encoded_labels = label_encoder.fit_transform(labels)
-        adata.obs["encoded_perturbations"] = encoded_labels
-
-        # Split the data in train, test and validation
-        X = list(range(0, adata.n_obs))
-        y = adata.obs[target_col]
-
-        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
-        X_train, X_val, y_train, y_val = train_test_split(
-            X_train, y_train, test_size=validation_split_size, stratify=y_train
-        )
-
-        train_dataset = PLDataset(
-            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-        )
-        val_dataset = PLDataset(
-            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-        )
-        test_dataset = PLDataset(
-            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-        )  # we don't need to pass y_test since the label selection is done inside
-
-        # Fix class unbalance (likely to happen in perturbation datasets)
-        # Usually control cells are overrepresented such that predicting control all time would give good results
-        # Cells with rare perturbations are sampled more
-        class_weights = 1.0 / torch.bincount(torch.tensor(train_dataset.labels.values))
-        train_weights = class_weights[train_dataset.labels]
-        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
-
-        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
-        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
-        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
-
-        # Define the network
-        sizes = [adata.n_vars] + hidden_dim + [n_classes]
-        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
-
-        # Define a dataset that gathers all the data and dataloader for getting embeddings
-        total_dataset = PLDataset(
-            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-        )
-        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4)
-
-        return self
-
-    def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = 2):
-        """Trains and test the defined model in the load step.
-
-        Args:
-            max_epochs: max epochs for training. Default to 40
-            val_epochs_check: check in validation dataset each val_epochs_check epochs
-            patience: patience before the early stopping flag is activated
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.papalexi_2021()['rna']
-            >>> dcs = pt.tl.DiscriminatorClassifierSpace()
-            >>> dcs.load(adata, target_col="gene_target")
-            >>> dcs.train(max_epochs=5)
-        """
-        self.trainer = pl.Trainer(
-            min_epochs=1,
-            max_epochs=max_epochs,
-            check_val_every_n_epoch=val_epochs_check,
-            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
-            devices="auto",
-            accelerator="auto",
-        )
-
-        self.model = PerturbationClassifier(model=self.net)
-
-        self.trainer.fit(
-            model=self.model, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader
-        )
-        self.trainer.test(model=self.model, dataloaders=self.test_dataloader)
-
-    def get_embeddings(self) -> AnnData:
-        """Access to the embeddings of the last layer.
-
-        Returns:
-            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.papalexi_2021()['rna']
-            >>> dcs = pt.tl.DiscriminatorClassifierSpace()
-            >>> dcs.load(adata, target_col="gene_target")
-            >>> dcs.train()
-            >>> embeddings = dcs.get_embeddings()
-        """
-        with torch.no_grad():
-            self.model.eval()
-            for dataset_count, batch in enumerate(self.entire_dataset):
-                emb, y = self.model.get_embeddings(batch)
-                batch_adata = AnnData(X=emb.cpu().numpy())
-                batch_adata.obs["perturbations"] = y
-                if dataset_count == 0:
-                    pert_adata = batch_adata
-                else:
-                    pert_adata = anndata.concat([pert_adata, batch_adata])
-
-        return pert_adata
-
-
-class MLP(torch.nn.Module):
-    """
-    A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
-    """
-
-    def __init__(
-        self,
-        sizes: list[int],
-        dropout: float = 0.0,
-        batch_norm: bool = True,
-        layer_norm: bool = False,
-        last_layer_act: str = "linear",
-    ) -> None:
-        """
-        Args:
-            sizes: size of layers
-            dropout: Dropout probability. Defaults to 0.0.
-            batch_norm: batch norm. Defaults to True.
-            layer_norm: layern norm, common in Transformers. Defaults to False.
-            last_layer_act: activation function of last layer. Defaults to "linear".
-        """
-        super().__init__()
-        layers = []
-        for s in range(len(sizes) - 1):
-            layers += [
-                torch.nn.Linear(sizes[s], sizes[s + 1]),
-                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
-                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
-                torch.nn.ReLU(),
-                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
-            ]
-
-        layers = [layer for layer in layers if layer is not None][:-1]
-        self.activation = last_layer_act
-        if self.activation == "linear":
-            pass
-        elif self.activation == "ReLU":
-            self.relu = torch.nn.ReLU()
-        else:
-            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
-
-        self.network = torch.nn.Sequential(*layers)
-
-        self.network.apply(init_weights)
-
-        self.sizes = sizes
-        self.batch_norm = batch_norm
-        self.layer_norm = layer_norm
-        self.last_layer_act = last_layer_act
-
-    def forward(self, x) -> torch.Tensor:
-        if self.activation == "ReLU":
-            return self.relu(self.network(x))
-        return self.network(x)
-
-    def embedding(self, x) -> torch.Tensor:
-        for layer in self.network[:-1]:
-            x = layer(x)
-        return x
-
-
-def init_weights(m):
-    if isinstance(m, torch.nn.Linear):
-        torch.nn.init.kaiming_uniform_(m.weight)
-        m.bias.data.fill_(0.01)
-
-
-class PLDataset(Dataset):
-    """
-    Dataset for perturbation classification.
-    Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
-    """
-
-    def __init__(
-        self,
-        adata: np.array,
-        target_col: str = "perturbations",
-        label_col: str = "perturbations",
-        layer_key: str = None,
-    ):
-        """
-        Args:
-            adata: AnnData object with observations and labels.
-            target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
-            label_col: key with the perturbation labels. Defaults to 'perturbations'.
-            layer_key: key of the layer to be used as data, otherwise .X
-        """
-
-        if layer_key:
-            self.data = adata.layers[layer_key]
-        else:
-            self.data = adata.X
-
-        self.labels = adata.obs[target_col]
-        self.pert_labels = adata.obs[label_col]
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx):
-        """Returns a sample and corresponding perturbations applied (labels)"""
-
-        sample = self.data[idx].A if scipy.sparse.issparse(self.data) else self.data[idx]
-        num_label = self.labels[idx]
-        str_label = self.pert_labels[idx]
-
-        return sample, num_label, str_label
-
-
-class PerturbationClassifier(pl.LightningModule):
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        layers: list = [512],  # noqa
-        dropout: float = 0.0,
-        batch_norm: bool = True,
-        layer_norm: bool = False,
-        last_layer_act: str = "linear",
-        lr=1e-4,
-        seed=42,
-    ):
-        """
-        Inputs:
-            layers - list: layers of the MLP
-        """
-        super().__init__()
-        self.save_hyperparameters()
-        if model:
-            self.net = model
-        else:
-            self._create_model()
-
-    def _create_model(self):
-        self.net = MLP(
-            sizes=self.hparams.layers,
-            dropout=self.hparams.dropout,
-            batch_norm=self.hparams.batch_norm,
-            layer_norm=self.hparams.layer_norm,
-            last_layer_act=self.hparams.last_layer_act,
-        )
-
-    def forward(self, x):
-        x = self.net(x)
-        return x
-
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
-
-        return optimizer
-
-    def training_step(self, batch, batch_idx):
-        x, y, _ = batch
-        x = x.to(torch.float32)
-        y = y.to(torch.long)
-
-        y_hat = self.forward(x)
-
-        loss = torch.nn.functional.cross_entropy(y_hat, y)
-        self.log("train_loss", loss, prog_bar=True)
-
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        x, y, _ = batch
-        x = x.to(torch.float32)
-        y = y.to(torch.long)
-
-        y_hat = self.forward(x)
-
-        loss = torch.nn.functional.cross_entropy(y_hat, y)
-        self.log("val_loss", loss, prog_bar=True)
-
-        return loss
-
-    def test_step(self, batch, batch_idx):
-        x, y, _ = batch
-        x = x.to(torch.float32)
-        y = y.to(torch.long)
-
-        y_hat = self.forward(x)
-
-        loss = torch.nn.functional.cross_entropy(y_hat, y)
-        self.log("test_loss", loss, prog_bar=True)
-
-        return loss
-
-    def embedding(self, x):
-        """
-        Inputs:
-            x - Input features of shape [Batch, SeqLen, 1]
-        """
-        x = self.net.embedding(x)
-        return x
-
-    def get_embeddings(self, batch):
-        x, _, y = batch
-        x = x.to(torch.float32)
-
-        embedding = self.embedding(x)
-        return embedding, y