pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,381 +0,0 @@
from __future__ import annotations

import anndata
import pytorch_lightning as pl
import scipy
import torch
from anndata import AnnData
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace


class DiscriminatorClassifierSpace(PerturbationSpace):
    """Leverages a discriminator classifier: fits a model to the data and takes its feature space.

    See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (dose-response analysis; Supp. 17-19).
    Either the model's coefficients for each perturbation are used directly as features, or a
    classifier (e.g. a simple MLP or logistic regression) is trained, its penultimate layer is
    taken as the feature space, and a pseudobulking approach is applied.
    """

    def load(  # type: ignore
        self,
        adata: AnnData,
        target_col: str = "perturbations",
        layer_key: str = None,
        hidden_dim: list[int] = None,
        dropout: float = 0.0,
        batch_norm: bool = True,
        batch_size: int = 256,
        test_split_size: float = 0.2,
        validation_split_size: float = 0.25,
    ):
        """Creates a model with the specified parameters (hidden_dim, dropout, batch_norm).

        It further creates the dataloaders and corrects the class imbalance caused by control cells.
        Sets the device to a GPU if available.

        Args:
            adata: AnnData object of size cells x genes.
            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
            layer_key: Layer to use as data; defaults to None, which uses .X.
            hidden_dim: List of hidden-layer sizes of the neural network, for instance [512, 256].
            dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
            batch_norm: Whether to apply batch normalization. Defaults to True.
            batch_size: The batch size. Defaults to 256.
            test_split_size: Fraction of cells held out for the test split. Defaults to 0.2.
            validation_split_size: Size of the validation split, taken relative to the resulting
                train split. Defaults to 0.25.

        Examples:
            >>> import pertpy as pt
            >>> adata = pt.dt.papalexi_2021()['rna']
            >>> dcs = pt.tl.DiscriminatorClassifierSpace()
            >>> dcs.load(adata, target_col="gene_target")
        """
        if layer_key is not None and layer_key not in adata.layers:
            raise ValueError(f"Layer key {layer_key!r} not found in adata.layers.")

        if target_col not in adata.obs:
            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")

        if hidden_dim is None:
            hidden_dim = [512]

        # Labels are strings; encode them as integers for classification
        n_classes = len(adata.obs[target_col].unique())
        labels = adata.obs[target_col]
        label_encoder = LabelEncoder()
        encoded_labels = label_encoder.fit_transform(labels)
        adata.obs["encoded_perturbations"] = encoded_labels

        # Split the data into train, test and validation sets
        X = list(range(0, adata.n_obs))
        y = adata.obs[target_col]

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=validation_split_size, stratify=y_train
        )

        train_dataset = PLDataset(
            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )
        val_dataset = PLDataset(
            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )
        test_dataset = PLDataset(
            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )  # we don't need to pass y_test since the label selection is done inside

        # Fix class imbalance (likely in perturbation datasets): control cells are usually so
        # overrepresented that always predicting control would already score well, so cells with
        # rare perturbations are sampled more often (see the worked example after this method).
        train_labels = torch.tensor(train_dataset.labels.values)
        class_weights = 1.0 / torch.bincount(train_labels)
        train_weights = class_weights[train_labels]
        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))

        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

        # Define the network
        sizes = [adata.n_vars] + hidden_dim + [n_classes]
        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)

        # Define a dataset that gathers all the data, and a dataloader for computing embeddings
        total_dataset = PLDataset(
            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )
        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4)

        return self
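    # Worked example of the inverse-frequency weighting above (made-up labels, for
    # illustration only): encoded labels [0, 0, 0, 1] give bincount counts [3, 1],
    # hence class_weights = [1/3, 1] and train_weights = [1/3, 1/3, 1/3, 1], so the
    # WeightedRandomSampler draws the one rare-perturbation cell three times as often
    # as each of the three control cells.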
    def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = 2):
        """Trains and tests the model defined in the load step.

        Args:
            max_epochs: Maximum number of training epochs. Defaults to 40.
            val_epochs_check: Run validation every val_epochs_check epochs. Defaults to 5.
            patience: Number of validation checks without improvement before early stopping triggers. Defaults to 2.

        Examples:
            >>> import pertpy as pt
            >>> adata = pt.dt.papalexi_2021()['rna']
            >>> dcs = pt.tl.DiscriminatorClassifierSpace()
            >>> dcs.load(adata, target_col="gene_target")
            >>> dcs.train(max_epochs=5)
        """
        self.trainer = pl.Trainer(
            min_epochs=1,
            max_epochs=max_epochs,
            check_val_every_n_epoch=val_epochs_check,
            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
            devices="auto",
            accelerator="auto",
        )

        self.model = PerturbationClassifier(model=self.net)

        self.trainer.fit(
            model=self.model, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader
        )
        self.trainer.test(model=self.model, dataloaders=self.test_dataloader)

    def get_embeddings(self) -> AnnData:
        """Accesses the embeddings of the penultimate network layer.

        Returns:
            AnnData whose `X` attribute is the perturbation embedding and whose
            .obs['perturbations'] holds the names of the perturbations.

        Examples:
            >>> import pertpy as pt
            >>> adata = pt.dt.papalexi_2021()['rna']
            >>> dcs = pt.tl.DiscriminatorClassifierSpace()
            >>> dcs.load(adata, target_col="gene_target")
            >>> dcs.train()
            >>> embeddings = dcs.get_embeddings()
        """
        with torch.no_grad():
            self.model.eval()
            for dataset_count, batch in enumerate(self.entire_dataset):
                emb, y = self.model.get_embeddings(batch)
                batch_adata = AnnData(X=emb.cpu().numpy())
                batch_adata.obs["perturbations"] = y
                if dataset_count == 0:
                    pert_adata = batch_adata
                else:
                    pert_adata = anndata.concat([pert_adata, batch_adata])

        return pert_adata
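# A hedged sketch of the pseudobulking step mentioned in the class docstring (illustrative
# only; `dcs` is the fitted space from the Examples above, the variable names are assumptions):
#
#     embeddings = dcs.get_embeddings()                   # cells x features AnnData
#     bulk = embeddings.to_df().groupby(embeddings.obs["perturbations"].values).mean()
#     # -> one averaged feature-space profile per perturbation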

class MLP(torch.nn.Module):
    """A multilayer perceptron with ReLU activations, optional dropout and optional batch normalization."""

    def __init__(
        self,
        sizes: list[int],
        dropout: float = 0.0,
        batch_norm: bool = True,
        layer_norm: bool = False,
        last_layer_act: str = "linear",
    ) -> None:
        """
        Args:
            sizes: Sizes of the layers.
            dropout: Dropout probability. Defaults to 0.0.
            batch_norm: Whether to apply batch normalization. Defaults to True.
            layer_norm: Whether to apply layer normalization, common in Transformers. Defaults to False.
            last_layer_act: Activation function of the last layer. Defaults to "linear".
        """
        super().__init__()
        layers = []
        for s in range(len(sizes) - 1):
            layers += [
                torch.nn.Linear(sizes[s], sizes[s + 1]),
                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
            ]

        # Drop the None placeholders and the trailing ReLU after the last linear layer
        layers = [layer for layer in layers if layer is not None][:-1]
        self.activation = last_layer_act
        if self.activation == "linear":
            pass
        elif self.activation == "ReLU":
            self.relu = torch.nn.ReLU()
        else:
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        self.network = torch.nn.Sequential(*layers)

        self.network.apply(init_weights)

        self.sizes = sizes
        self.batch_norm = batch_norm
        self.layer_norm = layer_norm
        self.last_layer_act = last_layer_act

    def forward(self, x) -> torch.Tensor:
        if self.activation == "ReLU":
            return self.relu(self.network(x))
        return self.network(x)

    def embedding(self, x) -> torch.Tensor:
        # Penultimate-layer activations: apply every module except the final linear layer
        for layer in self.network[:-1]:
            x = layer(x)
        return x
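# A minimal sketch of the feature-space behavior above (the layer sizes are made up):
#
#     mlp = MLP(sizes=[2000, 512, 10])   # 2000 genes -> one hidden layer -> 10 classes
#     x = torch.randn(8, 2000)           # a batch of 8 cells
#     mlp(x).shape                       # torch.Size([8, 10]): class logits
#     mlp.embedding(x).shape             # torch.Size([8, 512]): penultimate-layer features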
242
-
243
- def init_weights(m):
244
- if isinstance(m, torch.nn.Linear):
245
- torch.nn.init.kaiming_uniform_(m.weight)
246
- m.bias.data.fill_(0.01)
247
-
248
-
249
- class PLDataset(Dataset):
250
- """
251
- Dataset for perturbation classification.
252
- Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
253
- """
254
-
255
- def __init__(
256
- self,
257
- adata: np.array,
258
- target_col: str = "perturbations",
259
- label_col: str = "perturbations",
260
- layer_key: str = None,
261
- ):
262
- """
263
- Args:
264
- adata: AnnData object with observations and labels.
265
- target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
266
- label_col: key with the perturbation labels. Defaults to 'perturbations'.
267
- layer_key: key of the layer to be used as data, otherwise .X
268
- """
269
-
270
- if layer_key:
271
- self.data = adata.layers[layer_key]
272
- else:
273
- self.data = adata.X
274
-
275
- self.labels = adata.obs[target_col]
276
- self.pert_labels = adata.obs[label_col]
277
-
278
- def __len__(self):
279
- return len(self.data)
280
-
281
- def __getitem__(self, idx):
282
- """Returns a sample and corresponding perturbations applied (labels)"""
283
-
284
- sample = self.data[idx].A if scipy.sparse.issparse(self.data) else self.data[idx]
285
- num_label = self.labels[idx]
286
- str_label = self.pert_labels[idx]
287
-
288
- return sample, num_label, str_label
289
-
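# Illustrative only: PLDataset yields (sample, numeric_label, string_label) triples, which is
# exactly the unpacking that the Lightning module below relies on:
#
#     loader = DataLoader(PLDataset(adata, target_col="encoded_perturbations"), batch_size=256)
#     x, y, names = next(iter(loader))   # features, encoded labels, perturbation names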

class PerturbationClassifier(pl.LightningModule):
    def __init__(
        self,
        model: torch.nn.Module,
        layers: list = [512],  # noqa
        dropout: float = 0.0,
        batch_norm: bool = True,
        layer_norm: bool = False,
        last_layer_act: str = "linear",
        lr=1e-4,
        seed=42,
    ):
        """
        Args:
            model: Network to use. If None, an MLP is built from the remaining hyperparameters.
            layers: Layer sizes of the MLP built when no model is passed.
        """
        super().__init__()
        self.save_hyperparameters()
        if model:
            self.net = model
        else:
            self._create_model()

    def _create_model(self):
        self.net = MLP(
            sizes=self.hparams.layers,
            dropout=self.hparams.dropout,
            batch_norm=self.hparams.batch_norm,
            layer_norm=self.hparams.layer_norm,
            last_layer_act=self.hparams.last_layer_act,
        )

    def forward(self, x):
        x = self.net(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)

        return optimizer

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        x = x.to(torch.float32)
        y = y.to(torch.long)

        y_hat = self.forward(x)

        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        x = x.to(torch.float32)
        y = y.to(torch.long)

        y_hat = self.forward(x)

        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("val_loss", loss, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        x = x.to(torch.float32)
        y = y.to(torch.long)

        y_hat = self.forward(x)

        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("test_loss", loss, prog_bar=True)

        return loss

    def embedding(self, x):
        """
        Args:
            x: Input features of shape [batch, n_genes].
        """
        x = self.net.embedding(x)
        return x

    def get_embeddings(self, batch):
        x, _, y = batch
        x = x.to(torch.float32)

        embedding = self.embedding(x)
        return embedding, y
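# Condensed sketch of how the pieces fit together (this mirrors what
# DiscriminatorClassifierSpace.load/train do; the dataloader names are assumptions):
#
#     net = MLP(sizes=[adata.n_vars, 512, n_classes])
#     clf = PerturbationClassifier(model=net)
#     trainer = pl.Trainer(max_epochs=40, accelerator="auto", devices="auto")
#     trainer.fit(clf, train_dataloaders=train_loader, val_dataloaders=val_loader)
#     emb, names = clf.get_embeddings(next(iter(entire_loader)))   # penultimate-layer features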