pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -31,6 +31,8 @@ class ClusteringSpace(PerturbationSpace):
31
31
  true_label_col: ground truth labels.
32
32
  cluster_col: cluster computed labels.
33
33
  metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw'].
34
+ **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
35
+ For asw, metric, distances, sample_size, and random_state can be passed.
34
36
 
35
37
  Examples:
36
38
  Example usage with KMeansSpace:
@@ -39,7 +41,9 @@ class ClusteringSpace(PerturbationSpace):
39
41
  >>> mdata = pt.dt.papalexi_2021()
40
42
  >>> kmeans = pt.tl.KMeansSpace()
41
43
  >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
42
- >>> results = kmeans.evaluate_clustering(kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=['nmi'])
44
+ >>> results = kmeans.evaluate_clustering(
45
+ ... kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
46
+ ... )
43
47
  """
44
48
  if metrics is None:
45
49
  metrics = ["nmi", "ari", "asw"]
@@ -0,0 +1,526 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import TYPE_CHECKING, Literal
5
+
6
+ import anndata
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pytorch_lightning as pl
10
+ import scipy
11
+ import torch
12
+ from anndata import AnnData
13
+ from pytorch_lightning.callbacks import EarlyStopping
14
+ from sklearn.linear_model import LogisticRegression
15
+ from sklearn.model_selection import train_test_split
16
+ from sklearn.preprocessing import OneHotEncoder
17
+ from torch import optim
18
+ from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
19
+
20
+ from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
21
+
22
+
23
class LRClassifierSpace(PerturbationSpace):
    """Fits a logistic regression model to the data and takes the feature space as embedding.

    We fit one logistic regression model per perturbation. After training, the coefficients of the logistic regression
    model are used as the feature space. This results in one embedding per perturbation.
    """

    def compute(
        self,
        adata: AnnData,
        target_col: str = "perturbations",
        layer_key: str = None,
        embedding_key: str = None,
        test_split_size: float = 0.2,
        max_iter: int = 1000,
    ):
        """
        Fits a logistic regression model to the data and takes the coefficients of the logistic regression
        model as perturbation embedding.

        For every perturbation, a binary (one-vs-rest) balanced logistic regression is trained on a stratified
        train split and scored on the held-out split.

        Args:
            adata: AnnData object of size cells x genes
            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
            layer_key: Layer in adata to use. Defaults to None.
            embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
                Can only be specified if layer_key is None. Defaults to None.
            test_split_size: Fraction of data to put in the test set. Default to 0.2.
            max_iter: Maximum number of iterations taken for the solvers to converge. Defaults to 1000.

        Returns:
            AnnData object with one row per perturbation: the logistic regression coefficients as the embedding in X,
            the perturbations as .obs['perturbations'] and the test-split score of each classifier as
            .obs['classifier_score']. Columns of adata.obs that hold a single value within every perturbation are
            copied over as well.

        Raises:
            ValueError: If layer_key is not a layer of adata, embedding_key is not in adata.obsm, both are given,
                or target_col is not an .obs column.

        Examples:
            >>> import pertpy as pt
            >>> adata = pt.dt.norman_2019()
            >>> rcs = pt.tl.LRClassifierSpace()
            >>> pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
        """
        # Layers are stored in .layers (that is where the data is read from below), not in .obs.
        if layer_key is not None and layer_key not in adata.layers:
            raise ValueError(f"Layer key {layer_key} not found in adata.")

        if embedding_key is not None and embedding_key not in adata.obsm.keys():
            raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")

        if layer_key is not None and embedding_key is not None:
            raise ValueError("Cannot specify both layer_key and embedding_key.")

        if target_col not in adata.obs:
            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")

        if layer_key is not None:
            regression_data = adata.layers[layer_key]
        elif embedding_key is not None:
            regression_data = adata.obsm[embedding_key]
        else:
            regression_data = adata.X

        regression_labels = adata.obs[target_col]

        # Save adata observations for embedding annotations.
        # observed=True: for a categorical target_col, unused categories would otherwise form empty groups whose
        # aggregate is NaN, which would make every column fail the "no NaN" check below.
        adata_obs = adata.obs.reset_index(drop=True)
        adata_obs = adata_obs.groupby(target_col, observed=True).agg(
            lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
        )

        # Fit a logistic regression model for each perturbation
        regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
        regression_embeddings = {}
        regression_scores = {}

        for perturbation in regression_labels.unique():
            labels = np.where(regression_labels == perturbation, 1, 0)
            X_train, X_test, y_train, y_test = train_test_split(
                regression_data, labels, test_size=test_split_size, stratify=labels
            )

            regression_model.fit(X_train, y_train)
            # Copy, because the same estimator object is refit for every perturbation.
            regression_embeddings[perturbation] = regression_model.coef_.copy()
            regression_scores[perturbation] = regression_model.score(X_test, y_test)

        # Each coef_ has shape (1, n_features) for a binary problem. Stacking keeps X two-dimensional even with a
        # single feature, where squeezing would have collapsed it to 1-D.
        pert_adata = AnnData(X=np.vstack(list(regression_embeddings.values())))
        pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
        pert_adata.obs["classifier_score"] = list(regression_scores.values())

        # Save adata observations for embedding annotations
        for obs_name in adata_obs.columns:
            if not adata_obs[obs_name].isnull().values.any():
                pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
                    {pert: adata_obs.loc[pert][obs_name] for pert in adata_obs.index}
                )

        return pert_adata
116
+
117
+
118
# Ensure backward compatibility with DiscriminatorClassifierSpace
def DiscriminatorClassifierSpace():
    """Deprecated alias: emit a DeprecationWarning and return a new MLPClassifierSpace instance."""
    # The two literals are concatenated implicitly, so the first one must end with a space.
    warnings.warn(
        "The DiscriminatorClassifierSpace class is deprecated and will be removed in the future. "
        "Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    return MLPClassifierSpace()
128
+
129
+
130
class MLPClassifierSpace(PerturbationSpace):
    """Fits an ANN classifier to the data and takes the activations of its last hidden layer as embedding.

    We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
    feature space, resulting in one embedding per cell. Consider employing the PseudoBulk or another PerturbationSpace
    to obtain one embedding per perturbation.

    See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
    """

    def compute(  # type: ignore
        self,
        adata: AnnData,
        target_col: str = "perturbations",
        layer_key: str = None,
        hidden_dim: list[int] = None,
        dropout: float = 0.0,
        batch_norm: bool = True,
        batch_size: int = 256,
        test_split_size: float = 0.2,
        validation_split_size: float = 0.25,
        max_epochs: int = 20,
        val_epochs_check: int = 2,
        patience: int = 2,
    ) -> AnnData:
        """Creates cell embeddings by training a MLP classifier model to distinguish between perturbations.

        A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
        the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
        Dataloaders that take into account class imbalances are created: training cells are sampled with a probability
        inversely proportional to the frequency of their perturbation. Next, the model is trained and tested, using the
        GPU if available. The embeddings are obtained by passing the data through every layer of the MLP except the
        output layer. You will get one embedding per cell, so be aware that you might need to apply another
        perturbation space to aggregate the embeddings per perturbation.

        Args:
            adata: AnnData object of size cells x genes
            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
            layer_key: Layer in adata to use. Defaults to None.
            hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
                will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
                Defaults to [512].
            dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
            batch_norm: Whether to apply batch normalization. Defaults to True.
            batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
            test_split_size: Fraction of data to put in the test set. Default to 0.2.
            validation_split_size: Fraction of data to put in the validation set of the resultant train set.
                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
                will be used for validation. Defaults to 0.25.
            max_epochs: Maximum number of epochs for training. Defaults to 20.
            val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
                Note that this affects early stopping, as the model will be stopped if the validation performance does not
                improve for `patience` validation checks. Defaults to 2.
            patience: Number of validation performance checks without improvement, after which the early stopping flag
                is activated and training is therefore stopped. Defaults to 2.

        Returns:
            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
            The AnnData will have shape (n_cells, n_features), where n_features is the width of the last hidden layer
            (hidden_dim[-1] for a non-empty hidden_dim). The remaining columns of adata.obs are copied over in order.

        Raises:
            ValueError: If layer_key is not a layer of adata or target_col is not an .obs column.

        Examples:
            >>> import pertpy as pt
            >>> adata = pt.dt.norman_2019()
            >>> dcs = pt.tl.MLPClassifierSpace()
            >>> cell_embeddings = dcs.compute(adata, target_col="perturbation_name")
        """
        # Layers are stored in .layers (that is where PLDataset reads them from), not in .obs.
        if layer_key is not None and layer_key not in adata.layers:
            raise ValueError(f"Layer key {layer_key} not found in adata.")

        if target_col not in adata.obs:
            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")

        if hidden_dim is None:
            hidden_dim = [512]

        # Labels are strings, one hot encoding for classification.
        # to_numpy() gives a plain ndarray even for categorical/extension dtypes, which may not support reshape.
        n_classes = len(adata.obs[target_col].unique())
        labels = adata.obs[target_col].to_numpy().reshape(-1, 1)
        encoder = OneHotEncoder()
        encoded_labels = encoder.fit_transform(labels).toarray()
        adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]

        # Split the data in train, test and validation
        X = list(range(0, adata.n_obs))
        y = adata.obs[target_col]

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=validation_split_size, stratify=y_train
        )

        train_dataset = PLDataset(
            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )
        val_dataset = PLDataset(
            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )
        test_dataset = PLDataset(
            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )  # we don't need to pass y_test since the label selection is done inside

        # Fix class imbalance (likely to happen in perturbation datasets): control cells are usually overrepresented,
        # so that predicting control all the time would give good results. Each training cell is weighted by the
        # inverse frequency of its perturbation in the training split, so every perturbation is drawn equally often.
        # (The row sum of a one-hot label is always 1, so weighting by it would give every cell the same weight.)
        train_pert_labels = np.asarray(train_dataset.pert_labels)
        _, label_index, label_counts = np.unique(train_pert_labels, return_inverse=True, return_counts=True)
        train_weights = torch.as_tensor(1.0 / label_counts[label_index], dtype=torch.double)
        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))

        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

        # Define the network
        sizes = [adata.n_vars] + hidden_dim + [n_classes]
        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)

        # Define a dataset that gathers all the data and dataloader for getting embeddings
        total_dataset = PLDataset(
            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
        )
        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)

        # Save adata observations for embedding annotations in get_embeddings
        self.adata_obs = adata.obs.reset_index(drop=True)

        self.trainer = pl.Trainer(
            min_epochs=1,
            max_epochs=max_epochs,
            check_val_every_n_epoch=val_epochs_check,
            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
            devices="auto",
            accelerator="auto",
        )

        self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)

        self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
        self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)

        # Obtain cell embeddings. The embedding is already (batch, n_features), so it is not squeezed: squeezing
        # would drop the batch axis of a final batch that holds a single cell.
        batch_adatas = []
        with torch.no_grad():
            self.mlp.eval()
            for batch in self.entire_dataset:
                emb, y = self.mlp.get_embeddings(batch)
                batch_adata = AnnData(X=emb.cpu().numpy())
                batch_adata.obs["perturbations"] = y
                batch_adatas.append(batch_adata)
        # A single concatenation instead of one per batch, which was quadratic in the number of batches.
        pert_adata = anndata.concat(batch_adatas)

        # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
        # and we can just add the annotations from the original AnnData object
        pert_adata.obs = pert_adata.obs.reset_index(drop=True)
        if "perturbations" in self.adata_obs.columns:
            self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
        pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)

        # Drop the 'encoded_perturbations' colums, since this stores the one-hot encoded labels as numpy arrays,
        # which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
        pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
        # The same array-valued helper column would cause the same problems in the caller's AnnData, so remove it there too.
        del adata.obs["encoded_perturbations"]

        return pert_adata

    def load(self, adata, **kwargs):
        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
        raise DeprecationWarning(
            "The load method is deprecated and will be removed in the future. Please use the compute method instead."
        )

    def train(self, **kwargs):
        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
        raise DeprecationWarning(
            "The train method is deprecated and will be removed in the future. Please use the compute method instead."
        )

    def get_embeddings(self, **kwargs):
        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
        raise DeprecationWarning(
            "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
        )
310
+
311
+
312
class MLP(torch.nn.Module):
    """
    Multilayer perceptron: Linear layers joined by ReLU, with optional BatchNorm or LayerNorm and Dropout on the
    hidden layers.
    """

    def __init__(
        self,
        sizes: list[int],
        dropout: float = 0.0,
        batch_norm: bool = True,
        layer_norm: bool = False,
        last_layer_act: str = "linear",
    ) -> None:
        """
        Args:
            sizes: Widths of the layers, input first and output last; one Linear layer is created per adjacent pair.
            dropout: Dropout probability applied after each hidden layer. Defaults to 0.0.
            batch_norm: Whether to add BatchNorm1d after each hidden Linear layer. Defaults to True.
            layer_norm: Whether to add LayerNorm after each hidden Linear layer; only used when batch_norm is False.
                Defaults to False.
            last_layer_act: Activation applied to the network output, either "linear" (none) or "ReLU".
                Defaults to "linear".

        Raises:
            ValueError: If last_layer_act is neither "linear" nor "ReLU".
        """
        super().__init__()
        last_hidden = len(sizes) - 2
        modules = []
        for position, (width_in, width_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            hidden = position < last_hidden
            modules.append(torch.nn.Linear(width_in, width_out))
            if hidden and batch_norm:
                modules.append(torch.nn.BatchNorm1d(width_out))
            elif hidden and layer_norm:
                modules.append(torch.nn.LayerNorm(width_out))
            modules.append(torch.nn.ReLU())
            if hidden:
                modules.append(torch.nn.Dropout(dropout))
        # Remove the ReLU that follows the output layer; the output activation is governed by last_layer_act.
        modules = modules[:-1]

        self.activation = last_layer_act
        if last_layer_act == "ReLU":
            self.relu = torch.nn.ReLU()
        elif last_layer_act != "linear":
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        self.network = torch.nn.Sequential(*modules)

        self.network.apply(init_weights)

        self.sizes = sizes
        self.batch_norm = batch_norm
        self.layer_norm = layer_norm
        self.last_layer_act = last_layer_act

    def forward(self, x) -> torch.Tensor:
        """Run x through the whole network, then apply the optional output ReLU."""
        out = self.network(x)
        return self.relu(out) if self.activation == "ReLU" else out

    def embedding(self, x) -> torch.Tensor:
        """Run x through every module of the network except the last one and return the result."""
        hidden = x
        for module in list(self.network)[:-1]:
            hidden = module(hidden)
        return hidden
371
+
372
+
373
def init_weights(m):
    """Initialize a ``torch.nn.Linear`` module in place: Kaiming-uniform weights and a constant bias of 0.01.

    Intended for ``torch.nn.Module.apply``. Modules that are not Linear layers are left untouched, and Linear
    layers created with ``bias=False`` only get their weights initialized.
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight)
        # Linear(bias=False) has m.bias set to None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)
377
+
378
+
379
class PLDataset(Dataset):
    """
    Dataset for perturbation classification.
    Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
    """

    def __init__(
        self,
        adata: AnnData,
        target_col: str = "perturbations",
        label_col: str = "perturbations",
        layer_key: str = None,
    ):
        """
        Args:
            adata: AnnData object with observations and labels.
            target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
            label_col: key with the perturbation labels. Defaults to 'perturbations'.
            layer_key: key of the layer to be used as data, otherwise .X
        """
        self.data = adata.layers[layer_key] if layer_key else adata.X
        self.labels = adata.obs[target_col]
        self.pert_labels = adata.obs[label_col]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Returns a sample and corresponding perturbations applied (labels).

        Returns:
            Tuple ``(sample, num_label, str_label)`` for the observation at position ``idx``. For sparse data the
            row is densified to a 1-D array, the same shape a dense row has.
        """
        if scipy.sparse.issparse(self.data):
            # toarray() works for sparse matrices and sparse arrays; the `.A` shortcut is gone in recent SciPy.
            # ravel() (not squeeze) keeps a 1-D sample even when there is a single feature.
            sample = self.data[idx].toarray().ravel()
        else:
            sample = self.data[idx]
        num_label = self.labels.iloc[idx]
        str_label = self.pert_labels.iloc[idx]

        return sample, num_label, str_label
418
+
419
+
420
class PerturbationClassifier(pl.LightningModule):
    """Lightning module that trains a classifier network on (sample, one-hot label, label name) batches with cross-entropy."""

    def __init__(
        self,
        model: torch.nn.Module,
        batch_size: int,
        layers: list = [512],  # noqa
        dropout: float = 0.0,
        batch_norm: bool = True,
        layer_norm: bool = False,
        last_layer_act: str = "linear",
        lr=1e-4,
        seed=42,
    ):
        """
        Args:
            model: model to be trained. If None, an MLP is built from the remaining hyperparameters.
            batch_size: batch size (stored as an attribute)
            layers: list of layer sizes of the MLP; only used when model is None
            dropout: dropout probability; only used when model is None
            batch_norm: whether to apply batch norm; only used when model is None
            layer_norm: whether to apply layer norm; only used when model is None
            last_layer_act: activation function of last layer; only used when model is None
            lr: learning rate of the Adam optimizer
            seed: random seed. It is stored in the hyperparameters only; this module does not apply it.
        """
        super().__init__()
        self.batch_size = batch_size
        # The network is already checkpointed as a submodule; recording it as a hyperparameter too would store it
        # twice (Lightning warns about nn.Module hyperparameters).
        self.save_hyperparameters(ignore=["model"])
        if model is not None:
            self.net = model
        else:
            self._create_model()

    def _create_model(self):
        self.net = MLP(
            sizes=self.hparams.layers,
            dropout=self.hparams.dropout,
            batch_norm=self.hparams.batch_norm,
            layer_norm=self.hparams.layer_norm,
            last_layer_act=self.hparams.last_layer_act,
        )

    def forward(self, x):
        x = self.net(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)

        return optimizer

    def _shared_step(self, batch, stage: str) -> torch.Tensor:
        """Compute and log the cross-entropy loss of one batch as ``f"{stage}_loss"``."""
        x, y, _ = batch
        x = x.to(torch.float32)

        y_hat = self.forward(x)

        # Targets arrive one-hot encoded; cross_entropy expects class indices.
        y = torch.argmax(y, dim=1)

        # y_hat is not squeezed: for a batch holding one sample that would drop the batch axis and break cross_entropy.
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        # Log with the actual batch size so that a smaller final batch is weighted correctly in epoch averages.
        self.log(f"{stage}_loss", loss, prog_bar=True, batch_size=x.shape[0])

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def embedding(self, x):
        """Return ``self.net.embedding(x)``.

        Inputs:
            x: input features. In this file it is fed float32 batches built from PLDataset samples, presumably of
                shape (batch, n_features); confirm for other callers.
        """
        x = self.net.embedding(x)
        return x

    def get_embeddings(self, batch):
        """Return the embedding of the batch samples together with the batch's label names."""
        x, _, y = batch
        x = x.to(torch.float32)

        embedding = self.embedding(x)
        return embedding, y