pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
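The remainder of the page shows one hunk in full: the deletion of pertpy/tools/_perturbation_space/_discriminator_classifier.py (entry 62 above, the only file with exactly 381 removed lines; its functionality was superseded by the new _discriminator_classifiers.py, entry 39). For orientation before reading the raw diff: the removed module's own docstrings document its workflow, and the sketch below merely restates those docstring examples (the papalexi_2021 dataset and gene_target column come straight from them and are not re-verified against the 0.8.0 API):

    import pertpy as pt

    # Workflow of the removed DiscriminatorClassifierSpace, per its docstrings:
    # train an MLP to classify perturbation labels, then take the penultimate
    # layer as the perturbation embedding.
    adata = pt.dt.papalexi_2021()["rna"]
    dcs = pt.tl.DiscriminatorClassifierSpace()
    dcs.load(adata, target_col="gene_target")  # builds the model and dataloaders
    dcs.train(max_epochs=5)                    # trains with early stopping on val_loss
    embeddings = dcs.get_embeddings()          # AnnData of penultimate-layer features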
--- a/pertpy/tools/_perturbation_space/_discriminator_classifier.py
+++ /dev/null
@@ -1,381 +0,0 @@
- from __future__ import annotations
-
- from typing import TYPE_CHECKING
-
- import anndata
- import pytorch_lightning as pl
- import scipy
- import torch
- from anndata import AnnData
- from pytorch_lightning.callbacks import EarlyStopping
- from sklearn.model_selection import train_test_split
- from sklearn.preprocessing import LabelEncoder
- from torch import optim
- from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
-
- from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
-
- if TYPE_CHECKING:
-     import numpy as np
-
-
- class DiscriminatorClassifierSpace(PerturbationSpace):
-     """Leveraging discriminator classifier. Fit a regressor model to the data and take the feature space.
-
-     See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19)
-     We use either the coefficients of the model for each perturbation as a feature or train a classifier example
-     (simple MLP or logistic regression and take the penultimate layer as feature space and apply pseudobulking approach).
-     """
-
-     def load(  # type: ignore
-         self,
-         adata: AnnData,
-         target_col: str = "perturbations",
-         layer_key: str = None,
-         hidden_dim: list[int] = None,
-         dropout: float = 0.0,
-         batch_norm: bool = True,
-         batch_size: int = 256,
-         test_split_size: float = 0.2,
-         validation_split_size: float = 0.25,
-     ):
-         """Creates a model with the specified parameters (hidden_dim, dropout, batch_norm).
-
-         It further creates dataloaders and fixes class imbalance due to control.
-         Sets the device to a GPU if available.
-
-         Args:
-             adata: AnnData object of size cells x genes
-             target_col: .obs column that stores the perturbations. Defaults to "perturbations".
-             layer_key: Layer to use. Defaults to None.
-             hidden_dim: list of hidden layers of the neural network. For instance: [512, 256].
-             dropout: amount of dropout applied, constant for all layers. Defaults to 0.
-             batch_norm: Whether to apply batch normalization. Defaults to True.
-             batch_size: The batch size. Defaults to 256.
-             test_split_size: Default to 0.2.
-             validation_split_size: Size of the validation split taking into account that is taking with respect to the resultant train split.
-                 Defaults to 0.25.
-
-         Examples:
-             >>> import pertpy as pt
-             >>> adata = pt.dt.papalexi_2021()['rna']
-             >>> dcs = pt.tl.DiscriminatorClassifierSpace()
-             >>> dcs.load(adata, target_col="gene_target")
-         """
-         if layer_key is not None and layer_key not in adata.obs.columns:
-             raise ValueError(f"Layer key {layer_key} not found in adata. {layer_key}")
-
-         if target_col not in adata.obs:
-             raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
-
-         if hidden_dim is None:
-             hidden_dim = [512]
-
-         # Labels are strings, one hot encoding for classification
-         n_classes = len(adata.obs[target_col].unique())
-         labels = adata.obs[target_col]
-         label_encoder = LabelEncoder()
-         encoded_labels = label_encoder.fit_transform(labels)
-         adata.obs["encoded_perturbations"] = encoded_labels
-
-         # Split the data in train, test and validation
-         X = list(range(0, adata.n_obs))
-         y = adata.obs[target_col]
-
-         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
-         X_train, X_val, y_train, y_val = train_test_split(
-             X_train, y_train, test_size=validation_split_size, stratify=y_train
-         )
-
-         train_dataset = PLDataset(
-             adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-         )
-         val_dataset = PLDataset(
-             adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-         )
-         test_dataset = PLDataset(
-             adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-         )  # we don't need to pass y_test since the label selection is done inside
-
-         # Fix class unbalance (likely to happen in perturbation datasets)
-         # Usually control cells are overrepresented such that predicting control all time would give good results
-         # Cells with rare perturbations are sampled more
-         class_weights = 1.0 / torch.bincount(torch.tensor(train_dataset.labels.values))
-         train_weights = class_weights[train_dataset.labels]
-         train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
-
-         self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
-         self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
-         self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
-
-         # Define the network
-         sizes = [adata.n_vars] + hidden_dim + [n_classes]
-         self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
-
-         # Define a dataset that gathers all the data and dataloader for getting embeddings
-         total_dataset = PLDataset(
-             adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-         )
-         self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4)
-
-         return self
-
-     def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = 2):
-         """Trains and test the defined model in the load step.
-
-         Args:
-             max_epochs: max epochs for training. Default to 40
-             val_epochs_check: check in validation dataset each val_epochs_check epochs
-             patience: patience before the early stopping flag is activated
-
-         Examples:
-             >>> import pertpy as pt
-             >>> adata = pt.dt.papalexi_2021()['rna']
-             >>> dcs = pt.tl.DiscriminatorClassifierSpace()
-             >>> dcs.load(adata, target_col="gene_target")
-             >>> dcs.train(max_epochs=5)
-         """
-         self.trainer = pl.Trainer(
-             min_epochs=1,
-             max_epochs=max_epochs,
-             check_val_every_n_epoch=val_epochs_check,
-             callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
-             devices="auto",
-             accelerator="auto",
-         )
-
-         self.model = PerturbationClassifier(model=self.net)
-
-         self.trainer.fit(
-             model=self.model, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader
-         )
-         self.trainer.test(model=self.model, dataloaders=self.test_dataloader)
-
-     def get_embeddings(self) -> AnnData:
-         """Access to the embeddings of the last layer.
-
-         Returns:
-             AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
-
-         Examples:
-             >>> import pertpy as pt
-             >>> adata = pt.dt.papalexi_2021()['rna']
-             >>> dcs = pt.tl.DiscriminatorClassifierSpace()
-             >>> dcs.load(adata, target_col="gene_target")
-             >>> dcs.train()
-             >>> embeddings = dcs.get_embeddings()
-         """
-         with torch.no_grad():
-             self.model.eval()
-             for dataset_count, batch in enumerate(self.entire_dataset):
-                 emb, y = self.model.get_embeddings(batch)
-                 batch_adata = AnnData(X=emb.cpu().numpy())
-                 batch_adata.obs["perturbations"] = y
-                 if dataset_count == 0:
-                     pert_adata = batch_adata
-                 else:
-                     pert_adata = anndata.concat([pert_adata, batch_adata])
-
-         return pert_adata
-
-
- class MLP(torch.nn.Module):
-     """
-     A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
-     """
-
-     def __init__(
-         self,
-         sizes: list[int],
-         dropout: float = 0.0,
-         batch_norm: bool = True,
-         layer_norm: bool = False,
-         last_layer_act: str = "linear",
-     ) -> None:
-         """
-         Args:
-             sizes: size of layers
-             dropout: Dropout probability. Defaults to 0.0.
-             batch_norm: batch norm. Defaults to True.
-             layer_norm: layern norm, common in Transformers. Defaults to False.
-             last_layer_act: activation function of last layer. Defaults to "linear".
-         """
-         super().__init__()
-         layers = []
-         for s in range(len(sizes) - 1):
-             layers += [
-                 torch.nn.Linear(sizes[s], sizes[s + 1]),
-                 torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
-                 torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
-                 torch.nn.ReLU(),
-                 torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
-             ]
-
-         layers = [layer for layer in layers if layer is not None][:-1]
-         self.activation = last_layer_act
-         if self.activation == "linear":
-             pass
-         elif self.activation == "ReLU":
-             self.relu = torch.nn.ReLU()
-         else:
-             raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
-
-         self.network = torch.nn.Sequential(*layers)
-
-         self.network.apply(init_weights)
-
-         self.sizes = sizes
-         self.batch_norm = batch_norm
-         self.layer_norm = layer_norm
-         self.last_layer_act = last_layer_act
-
-     def forward(self, x) -> torch.Tensor:
-         if self.activation == "ReLU":
-             return self.relu(self.network(x))
-         return self.network(x)
-
-     def embedding(self, x) -> torch.Tensor:
-         for layer in self.network[:-1]:
-             x = layer(x)
-         return x
-
-
- def init_weights(m):
-     if isinstance(m, torch.nn.Linear):
-         torch.nn.init.kaiming_uniform_(m.weight)
-         m.bias.data.fill_(0.01)
-
-
- class PLDataset(Dataset):
-     """
-     Dataset for perturbation classification.
-     Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
-     """
-
-     def __init__(
-         self,
-         adata: np.array,
-         target_col: str = "perturbations",
-         label_col: str = "perturbations",
-         layer_key: str = None,
-     ):
-         """
-         Args:
-             adata: AnnData object with observations and labels.
-             target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
-             label_col: key with the perturbation labels. Defaults to 'perturbations'.
-             layer_key: key of the layer to be used as data, otherwise .X
-         """
-
-         if layer_key:
-             self.data = adata.layers[layer_key]
-         else:
-             self.data = adata.X
-
-         self.labels = adata.obs[target_col]
-         self.pert_labels = adata.obs[label_col]
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, idx):
-         """Returns a sample and corresponding perturbations applied (labels)"""
-
-         sample = self.data[idx].A if scipy.sparse.issparse(self.data) else self.data[idx]
-         num_label = self.labels[idx]
-         str_label = self.pert_labels[idx]
-
-         return sample, num_label, str_label
-
-
- class PerturbationClassifier(pl.LightningModule):
-     def __init__(
-         self,
-         model: torch.nn.Module,
-         layers: list = [512],  # noqa
-         dropout: float = 0.0,
-         batch_norm: bool = True,
-         layer_norm: bool = False,
-         last_layer_act: str = "linear",
-         lr=1e-4,
-         seed=42,
-     ):
-         """
-         Inputs:
-             layers - list: layers of the MLP
-         """
-         super().__init__()
-         self.save_hyperparameters()
-         if model:
-             self.net = model
-         else:
-             self._create_model()
-
-     def _create_model(self):
-         self.net = MLP(
-             sizes=self.hparams.layers,
-             dropout=self.hparams.dropout,
-             batch_norm=self.hparams.batch_norm,
-             layer_norm=self.hparams.layer_norm,
-             last_layer_act=self.hparams.last_layer_act,
-         )
-
-     def forward(self, x):
-         x = self.net(x)
-         return x
-
-     def configure_optimizers(self):
-         optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
-
-         return optimizer
-
-     def training_step(self, batch, batch_idx):
-         x, y, _ = batch
-         x = x.to(torch.float32)
-         y = y.to(torch.long)
-
-         y_hat = self.forward(x)
-
-         loss = torch.nn.functional.cross_entropy(y_hat, y)
-         self.log("train_loss", loss, prog_bar=True)
-
-         return loss
-
-     def validation_step(self, batch, batch_idx):
-         x, y, _ = batch
-         x = x.to(torch.float32)
-         y = y.to(torch.long)
-
-         y_hat = self.forward(x)
-
-         loss = torch.nn.functional.cross_entropy(y_hat, y)
-         self.log("val_loss", loss, prog_bar=True)
-
-         return loss
-
-     def test_step(self, batch, batch_idx):
-         x, y, _ = batch
-         x = x.to(torch.float32)
-         y = y.to(torch.long)
-
-         y_hat = self.forward(x)
-
-         loss = torch.nn.functional.cross_entropy(y_hat, y)
-         self.log("test_loss", loss, prog_bar=True)
-
-         return loss
-
-     def embedding(self, x):
-         """
-         Inputs:
-             x - Input features of shape [Batch, SeqLen, 1]
-         """
-         x = self.net.embedding(x)
-         return x
-
-     def get_embeddings(self, batch):
-         x, _, y = batch
-         x = x.to(torch.float32)
-
-         embedding = self.embedding(x)
-         return embedding, y
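Two techniques in the removed module are worth calling out. First, the class-imbalance fix in load (control cells typically dominate perturbation datasets) is plain inverse-frequency weighted sampling. A minimal, self-contained sketch of that technique outside pertpy (the synthetic labels and sizes are illustrative only):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

    # Heavily imbalanced synthetic labels: class 0 ("control") dominates.
    labels = torch.tensor([0] * 900 + [1] * 80 + [2] * 20)
    features = torch.randn(len(labels), 16)

    # Inverse-frequency weights: rare classes get larger weights, so the
    # sampler draws them more often (with replacement by default).
    class_weights = 1.0 / torch.bincount(labels)
    sample_weights = class_weights[labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))

    loader = DataLoader(TensorDataset(features, labels), batch_size=64, sampler=sampler)
    batch_x, batch_y = next(iter(loader))  # roughly class-balanced in expectation

Second, MLP.embedding obtains the feature space by running every module except the final Linear head. The same idea on a bare torch.nn.Sequential (layer sizes arbitrary):

    import torch

    net = torch.nn.Sequential(
        torch.nn.Linear(16, 512),
        torch.nn.ReLU(),
        torch.nn.Linear(512, 3),  # classification head
    )
    x = torch.randn(8, 16)
    logits = net(x)          # shape (8, 3): class scores
    embedding = net[:-1](x)  # shape (8, 512): penultimate-layer features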