pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_clustering.py
@@ -31,6 +31,8 @@ class ClusteringSpace(PerturbationSpace):
             true_label_col: ground truth labels.
             cluster_col: cluster computed labels.
             metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw'].
+            **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
+                For asw, metric, distances, sample_size, and random_state can be passed.
 
         Examples:
             Example usage with KMeansSpace:
@@ -39,7 +41,9 @@ class ClusteringSpace(PerturbationSpace):
             >>> mdata = pt.dt.papalexi_2021()
             >>> kmeans = pt.tl.KMeansSpace()
             >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
-            >>> results = kmeans.evaluate_clustering(kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=['nmi'])
+            >>> results = kmeans.evaluate_clustering(
+            ...     kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
+            ... )
             """
         if metrics is None:
             metrics = ["nmi", "ari", "asw"]
pertpy/tools/_perturbation_space/_discriminator_classifiers.py (new file)
@@ -0,0 +1,526 @@
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Literal
+
+import anndata
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import scipy
+import torch
+from anndata import AnnData
+from pytorch_lightning.callbacks import EarlyStopping
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import OneHotEncoder
+from torch import optim
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
+
+from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
+
+
+class LRClassifierSpace(PerturbationSpace):
+    """Fits a logistic regression model to the data and takes the feature space as embedding.
+
+    We fit one logistic regression model per perturbation. After training, the coefficients of the logistic regression
+    model are used as the feature space. This results in one embedding per perturbation.
+    """
+
+    def compute(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str | None = None,
+        embedding_key: str | None = None,
+        test_split_size: float = 0.2,
+        max_iter: int = 1000,
+    ):
+        """Fits a logistic regression model to the data and takes the coefficients of the logistic regression
+        model as perturbation embedding.
+
+        Args:
+            adata: AnnData object of size cells x genes.
+            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
+            layer_key: Layer in adata to use. Defaults to None.
+            embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
+                Can only be specified if layer_key is None. Defaults to None.
+            test_split_size: Fraction of data to put in the test set. Defaults to 0.2.
+            max_iter: Maximum number of iterations taken for the solvers to converge. Defaults to 1000.
+
+        Returns:
+            AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> rcs = pt.tl.LRClassifierSpace()
+            >>> pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.layers:
+            raise ValueError(f"Layer key {layer_key} not found in adata.layers.")
+
+        if embedding_key is not None and embedding_key not in adata.obsm.keys():
+            raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")
+
+        if layer_key is not None and embedding_key is not None:
+            raise ValueError("Cannot specify both layer_key and embedding_key.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if layer_key is not None:
+            regression_data = adata.layers[layer_key]
+        elif embedding_key is not None:
+            regression_data = adata.obsm[embedding_key]
+        else:
+            regression_data = adata.X
+
+        regression_labels = adata.obs[target_col]
+
+        # Save adata observations to annotate the embedding AnnData later;
+        # keep only per-perturbation metadata that is constant within each group
+        adata_obs = adata.obs.reset_index(drop=True)
+        adata_obs = adata_obs.groupby(target_col).agg(
+            lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
+        )
+
+        # Fit one binary (one-vs-rest) logistic regression model per perturbation
+        regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
+        regression_embeddings = {}
+        regression_scores = {}
+
+        for perturbation in regression_labels.unique():
+            labels = np.where(regression_labels == perturbation, 1, 0)
+            X_train, X_test, y_train, y_test = train_test_split(
+                regression_data, labels, test_size=test_split_size, stratify=labels
+            )
+
+            regression_model.fit(X_train, y_train)
+            regression_embeddings[perturbation] = regression_model.coef_
+            regression_scores[perturbation] = regression_model.score(X_test, y_test)
+
+        # Save the regression embeddings and scores in an AnnData object
+        pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
+        pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
+        pert_adata.obs["classifier_score"] = list(regression_scores.values())
+
+        # Transfer the per-perturbation observation annotations
+        for obs_name in adata_obs.columns:
+            if not adata_obs[obs_name].isnull().values.any():
+                pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
+                    {pert: adata_obs.loc[pert][obs_name] for pert in adata_obs.index}
+                )
+
+        return pert_adata
+
+
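The mechanics above reduce to one one-vs-rest classifier per perturbation, whose coefficient vector becomes that perturbation's embedding. A minimal, self-contained sketch of the idea on synthetic data (array shapes and names here are illustrative, not pertpy API):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 50))                     # cells x features
perts = rng.choice(["ctrl", "KO_A", "KO_B"], 300)  # perturbation label per cell

embeddings, scores = {}, {}
for pert in np.unique(perts):
    y = (perts == pert).astype(int)  # one-vs-rest labels for this perturbation
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, stratify=y)
    clf = LogisticRegression(max_iter=1000, class_weight="balanced").fit(X_tr, y_tr)
    embeddings[pert] = clf.coef_.ravel()  # one 50-dim vector per perturbation
    scores[pert] = clf.score(X_te, y_te)  # held-out accuracy of this classifier
```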
+# Ensure backward compatibility with DiscriminatorClassifierSpace
+def DiscriminatorClassifierSpace():
+    warnings.warn(
+        "The DiscriminatorClassifierSpace class is deprecated and will be removed in the future. "
+        "Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    return MLPClassifierSpace()
+
+
+class MLPClassifierSpace(PerturbationSpace):
+    """Fits an ANN classifier to the data and takes the activations of the penultimate layer as embedding.
+
+    We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
+    feature space, resulting in one embedding per cell. Consider employing PseudobulkSpace or another PerturbationSpace
+    to obtain one embedding per perturbation.
+
+    See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (dose-response analysis) and Sup. 17-19.
+    """
+
+    def compute(  # type: ignore
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str | None = None,
+        hidden_dim: list[int] | None = None,
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        batch_size: int = 256,
+        test_split_size: float = 0.2,
+        validation_split_size: float = 0.25,
+        max_epochs: int = 20,
+        val_epochs_check: int = 2,
+        patience: int = 2,
+    ) -> AnnData:
+        """Creates cell embeddings by training an MLP classifier model to distinguish between perturbations.
+
+        A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
+        the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
+        Dataloaders that take into account class imbalances are created. Next, the model is trained and tested, using the
+        GPU if available. The embeddings are obtained by passing the data through the model and extracting the activations
+        of the penultimate layer of the MLP. You will get one embedding per cell, so be aware that you might need to apply
+        another perturbation space to aggregate the embeddings per perturbation.
+
+        Args:
+            adata: AnnData object of size cells x genes.
+            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
+            layer_key: Layer in adata to use. Defaults to None.
+            hidden_dim: List of the number of neurons in each hidden layer of the neural network. For instance, [512, 256]
+                creates a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
+                Defaults to [512].
+            dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
+            batch_norm: Whether to apply batch normalization. Defaults to True.
+            batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
+            test_split_size: Fraction of data to put in the test set. Defaults to 0.2.
+            validation_split_size: Fraction of data to put in the validation set of the resultant train set.
+                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
+                will be used for validation. Defaults to 0.25.
+            max_epochs: Maximum number of epochs for training. Defaults to 20.
+            val_epochs_check: Test performance on the validation dataset after every val_epochs_check training epochs.
+                Note that this affects early stopping, as the model will be stopped if the validation performance does not
+                improve for patience checks. Defaults to 2.
+            patience: Number of validation performance checks without improvement, after which the early stopping flag
+                is activated and training is therefore stopped. Defaults to 2.
+
+        Returns:
+            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
+            The AnnData will have shape (n_cells, n_features) where n_features is the number of neurons in the last hidden layer of the MLP.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> dcs = pt.tl.MLPClassifierSpace()
+            >>> cell_embeddings = dcs.compute(adata, target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.layers:
+            raise ValueError(f"Layer key {layer_key} not found in adata.layers.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if hidden_dim is None:
+            hidden_dim = [512]
+
+        # Labels are strings, so one-hot encode them for classification
+        n_classes = len(adata.obs[target_col].unique())
+        labels = adata.obs[target_col].values.reshape(-1, 1)
+        encoder = OneHotEncoder()
+        encoded_labels = encoder.fit_transform(labels).toarray()
+        adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
+
+        # Split the data into train, test and validation sets
+        X = list(range(0, adata.n_obs))
+        y = adata.obs[target_col]
+
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
+        X_train, X_val, y_train, y_val = train_test_split(
+            X_train, y_train, test_size=validation_split_size, stratify=y_train
+        )
+
+        train_dataset = PLDataset(
+            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        val_dataset = PLDataset(
+            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        test_dataset = PLDataset(
+            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )  # we don't need to pass y_test since the label selection is done inside
+
+        # Correct for class imbalance (likely to happen in perturbation datasets):
+        # control cells are usually overrepresented, such that always predicting control would already score well.
+        # Cells with rare perturbations are therefore sampled more often.
+        train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
+        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
+
+        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
+        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+        # Define the network
+        sizes = [adata.n_vars] + hidden_dim + [n_classes]
+        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
+
+        # Define a dataset that gathers all the data and a dataloader for getting embeddings
+        total_dataset = PLDataset(
+            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)
+
+        # Save adata observations to annotate the embeddings later
+        self.adata_obs = adata.obs.reset_index(drop=True)
+
+        self.trainer = pl.Trainer(
+            min_epochs=1,
+            max_epochs=max_epochs,
+            check_val_every_n_epoch=val_epochs_check,
+            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
+            devices="auto",
+            accelerator="auto",
+        )
+
+        self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
+
+        self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
+        self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)
+
+        # Obtain cell embeddings
+        with torch.no_grad():
+            self.mlp.eval()
+            for dataset_count, batch in enumerate(self.entire_dataset):
+                emb, y = self.mlp.get_embeddings(batch)
+                emb = torch.squeeze(emb)
+                batch_adata = AnnData(X=emb.cpu().numpy())
+                batch_adata.obs["perturbations"] = y
+                if dataset_count == 0:
+                    pert_adata = batch_adata
+                else:
+                    pert_adata = anndata.concat([pert_adata, batch_adata])
+
+        # Add .obs annotations to pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
+        # and we can just add the annotations from the original AnnData object
+        pert_adata.obs = pert_adata.obs.reset_index(drop=True)
+        if "perturbations" in self.adata_obs.columns:
+            self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
+        pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)
+
+        # Drop the 'encoded_perturbations' column, since it stores the one-hot encoded labels as numpy arrays,
+        # which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
+        pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
+
+        return pert_adata
+
+    def load(self, adata, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The load method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def train(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The train method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def get_embeddings(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+
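Because compute returns one embedding per cell, a typical follow-up is aggregation into per-perturbation embeddings. A hedged usage sketch (PseudobulkSpace follows the pertpy API; treat the exact argument names, in particular mode, as assumptions):

```python
import pertpy as pt

adata = pt.dt.norman_2019()
mlp_space = pt.tl.MLPClassifierSpace()
cell_embeddings = mlp_space.compute(adata, target_col="perturbation_name", max_epochs=20)

# Aggregate cell-level embeddings into one embedding per perturbation
ps = pt.tl.PseudobulkSpace()
pert_embeddings = ps.compute(cell_embeddings, target_col="perturbations", mode="mean")
```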
+class MLP(torch.nn.Module):
+    """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
+
+    def __init__(
+        self,
+        sizes: list[int],
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+    ) -> None:
+        """
+        Args:
+            sizes: Sizes of the layers.
+            dropout: Dropout probability. Defaults to 0.0.
+            batch_norm: Whether batch norm should be applied. Defaults to True.
+            layer_norm: Whether layer norm should be applied, as commonly used in Transformers. Defaults to False.
+            last_layer_act: Activation function of the last layer. Defaults to "linear".
+        """
+        super().__init__()
+        layers = []
+        for s in range(len(sizes) - 1):
+            layers += [
+                torch.nn.Linear(sizes[s], sizes[s + 1]),
+                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
+                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
+                torch.nn.ReLU(),
+                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
+            ]
+
+        # Drop the unused placeholders and the trailing ReLU so the output layer stays linear
+        layers = [layer for layer in layers if layer is not None][:-1]
+        self.activation = last_layer_act
+        if self.activation == "linear":
+            pass
+        elif self.activation == "ReLU":
+            self.relu = torch.nn.ReLU()
+        else:
+            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
+
+        self.network = torch.nn.Sequential(*layers)
+
+        self.network.apply(init_weights)
+
+        self.sizes = sizes
+        self.batch_norm = batch_norm
+        self.layer_norm = layer_norm
+        self.last_layer_act = last_layer_act
+
+    def forward(self, x) -> torch.Tensor:
+        if self.activation == "ReLU":
+            return self.relu(self.network(x))
+        return self.network(x)
+
+    def embedding(self, x) -> torch.Tensor:
+        # Pass through all layers except the final Linear to get the penultimate activations
+        for layer in self.network[:-1]:
+            x = layer(x)
+        return x
+
+
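A quick sanity sketch of how the sizes list maps onto layers, assuming MLP is imported from the module this diff adds (shapes follow the construction logic above):

```python
import torch
from pertpy.tools._perturbation_space._discriminator_classifiers import MLP

# sizes=[20, 512, 5] builds Linear(20, 512) -> BatchNorm1d -> ReLU -> Dropout -> Linear(512, 5);
# the trailing ReLU is trimmed so the output stays linear.
net = MLP(sizes=[20, 512, 5], dropout=0.1, batch_norm=True)
net.eval()  # disable Dropout and use BatchNorm running stats for a deterministic pass

x = torch.randn(8, 20)
print(net(x).shape)            # torch.Size([8, 5])   - class logits
print(net.embedding(x).shape)  # torch.Size([8, 512]) - penultimate activations used as embedding
```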
+def init_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_uniform_(m.weight)
+        m.bias.data.fill_(0.01)
+
+
+class PLDataset(Dataset):
+    """Dataset for perturbation classification.
+
+    Needed for training a model that classifies the perturbed cells and takes the second-to-last layer as perturbation embedding.
+    """
+
+    def __init__(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        label_col: str = "perturbations",
+        layer_key: str | None = None,
+    ):
+        """
+        Args:
+            adata: AnnData object with observations and labels.
+            target_col: Key with the numerically encoded perturbation labels. Defaults to 'perturbations'.
+            label_col: Key with the perturbation labels. Defaults to 'perturbations'.
+            layer_key: Key of the layer to be used as data; otherwise .X is used.
+        """
+        if layer_key:
+            self.data = adata.layers[layer_key]
+        else:
+            self.data = adata.X
+
+        self.labels = adata.obs[target_col]
+        self.pert_labels = adata.obs[label_col]
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        """Returns a sample and the corresponding perturbations applied (labels)."""
+        sample = self.data[idx].A.squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
+        num_label = self.labels.iloc[idx]
+        str_label = self.pert_labels.iloc[idx]
+
+        return sample, num_label, str_label
+
+
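A hedged sketch of feeding PLDataset to a DataLoader on synthetic data (the one-hot column mirrors what compute stores in .obs['encoded_perturbations']; the data and names here are illustrative):

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from torch.utils.data import DataLoader
from pertpy.tools._perturbation_space._discriminator_classifiers import PLDataset

X = np.random.rand(100, 20).astype(np.float32)
obs = pd.DataFrame({"perturbations": np.random.choice(["ctrl", "KO_A"], 100)})
adata = AnnData(X=X, obs=obs)

# Mimic the one-hot labels that MLPClassifierSpace.compute stores in .obs
adata.obs["encoded_perturbations"] = [
    np.float32([1.0, 0.0]) if p == "ctrl" else np.float32([0.0, 1.0])
    for p in adata.obs["perturbations"]
]

ds = PLDataset(adata, target_col="encoded_perturbations", label_col="perturbations")
samples, num_labels, str_labels = next(iter(DataLoader(ds, batch_size=4)))
```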
+class PerturbationClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        layers: list = [512],  # noqa
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+        lr=1e-4,
+        seed=42,
+    ):
+        """
+        Args:
+            model: Model to be trained. If None, a new MLP is created from the remaining hyperparameters.
+            batch_size: Batch size.
+            layers: List of layer sizes of the MLP.
+            dropout: Dropout probability.
+            batch_norm: Whether to apply batch norm.
+            layer_norm: Whether to apply layer norm.
+            last_layer_act: Activation function of the last layer.
+            lr: Learning rate.
+            seed: Random seed.
+        """
+        super().__init__()
+        self.batch_size = batch_size
+        self.save_hyperparameters()
+        if model:
+            self.net = model
+        else:
+            self._create_model()
+
+    def _create_model(self):
+        self.net = MLP(
+            sizes=self.hparams.layers,
+            dropout=self.hparams.dropout,
+            batch_norm=self.hparams.batch_norm,
+            layer_norm=self.hparams.layer_norm,
+            last_layer_act=self.hparams.last_layer_act,
+        )
+
+    def forward(self, x):
+        x = self.net(x)
+        return x
+
+    def configure_optimizers(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
+
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        # One-hot labels -> class indices for cross entropy
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("test_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def embedding(self, x):
+        """
+        Args:
+            x: Input features of shape [batch_size, n_features].
+        """
+        x = self.net.embedding(x)
+        return x
+
+    def get_embeddings(self, batch):
+        x, _, y = batch
+        x = x.to(torch.float32)
+
+        embedding = self.embedding(x)
+        return embedding, y
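To close the loop, a hedged sketch of driving PerturbationClassifier directly, mirroring what MLPClassifierSpace.compute does internally. train_loader and val_loader are hypothetical DataLoaders built over PLDataset as in the previous sketch:

```python
import pytorch_lightning as pl
import torch
from pertpy.tools._perturbation_space._discriminator_classifiers import (
    MLP,
    PerturbationClassifier,
)

# 20 input features, one hidden layer of 512 neurons, 2 perturbation classes
clf = PerturbationClassifier(model=MLP(sizes=[20, 512, 2]), batch_size=4)
trainer = pl.Trainer(max_epochs=2, accelerator="auto", devices="auto")

# train_loader / val_loader: assumed DataLoaders over PLDataset (see previous sketch)
trainer.fit(clf, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Extract penultimate-layer embeddings for one batch, as compute() does batch by batch
with torch.no_grad():
    clf.eval()
    emb, labels = clf.get_embeddings(next(iter(train_loader)))
```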