pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_clustering.py
@@ -31,6 +31,8 @@ class ClusteringSpace(PerturbationSpace):
             true_label_col: ground truth labels.
             cluster_col: cluster computed labels.
             metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw'].
+            **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
+                For asw, metric, distances, sample_size, and random_state can be passed.
 
         Examples:
             Example usage with KMeansSpace:
@@ -39,7 +41,9 @@ class ClusteringSpace(PerturbationSpace):
             >>> mdata = pt.dt.papalexi_2021()
             >>> kmeans = pt.tl.KMeansSpace()
             >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
-            >>> results = kmeans.evaluate_clustering(
+            >>> results = kmeans.evaluate_clustering(
+            ...     kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
+            ... )
         """
         if metrics is None:
             metrics = ["nmi", "ari", "asw"]
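The hunk above documents the new `**kwargs` pass-through on `evaluate_clustering`. As a quick illustration, here is a sketch extending the docstring example, assuming the keyword arguments are forwarded to the underlying scikit-learn metric functions exactly as the added docstring lines describe:

```python
import pertpy as pt

mdata = pt.dt.papalexi_2021()
kmeans = pt.tl.KMeansSpace()
kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)

# Metric-specific options now travel through **kwargs; "arithmetic" is one of
# the accepted average_method values for scikit-learn's normalized mutual information.
results = kmeans.evaluate_clustering(
    kmeans_adata,
    true_label_col="gene_target",
    cluster_col="k-means",
    metrics=["nmi"],
    average_method="arithmetic",
)
```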
pertpy/tools/_perturbation_space/_discriminator_classifiers.py (new file)
@@ -0,0 +1,526 @@
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Literal
+
+import anndata
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import scipy
+import torch
+from anndata import AnnData
+from pytorch_lightning.callbacks import EarlyStopping
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import OneHotEncoder
+from torch import optim
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
+
+from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
+
+
+class LRClassifierSpace(PerturbationSpace):
+    """Fits a logistic regression model to the data and takes the feature space as embedding.
+
+    We fit one logistic regression model per perturbation. After training, the coefficients of the logistic regression
+    model are used as the feature space. This results in one embedding per perturbation.
+    """
+
+    def compute(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str = None,
+        embedding_key: str = None,
+        test_split_size: float = 0.2,
+        max_iter: int = 1000,
+    ):
+        """
+        Fits a logistic regression model to the data and takes the coefficients of the logistic regression
+        model as perturbation embedding.
+
+        Args:
+            adata: AnnData object of size cells x genes
+            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
+            layer_key: Layer in adata to use. Defaults to None.
+            embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
+                Can only be specified if layer_key is None. Defaults to None.
+            test_split_size: Fraction of data to put in the test set. Default to 0.2.
+            max_iter: Maximum number of iterations taken for the solvers to converge. Defaults to 1000.
+
+        Returns:
+            AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> rcs = pt.tl.LRClassifierSpace()
+            >>> pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.obs.columns:
+            raise ValueError(f"Layer key {layer_key} not found in adata.")
+
+        if embedding_key is not None and embedding_key not in adata.obsm.keys():
+            raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")
+
+        if layer_key is not None and embedding_key is not None:
+            raise ValueError("Cannot specify both layer_key and embedding_key.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if layer_key is not None:
+            regression_data = adata.layers[layer_key]
+        elif embedding_key is not None:
+            regression_data = adata.obsm[embedding_key]
+        else:
+            regression_data = adata.X
+
+        regression_labels = adata.obs[target_col]
+
+        # Save adata observations for embedding annotations in get_embeddings
+        adata_obs = adata.obs.reset_index(drop=True)
+        adata_obs = adata_obs.groupby(target_col).agg(
+            lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
+        )
+
+        # Fit a logistic regression model for each perturbation
+        regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
+        regression_embeddings = {}
+        regression_scores = {}
+
+        for perturbation in regression_labels.unique():
+            labels = np.where(regression_labels == perturbation, 1, 0)
+            X_train, X_test, y_train, y_test = train_test_split(
+                regression_data, labels, test_size=test_split_size, stratify=labels
+            )
+
+            regression_model.fit(X_train, y_train)
+            regression_embeddings[perturbation] = regression_model.coef_
+            regression_scores[perturbation] = regression_model.score(X_test, y_test)
+
+        # Save the regression embeddings and scores in an AnnData object
+        pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
+        pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
+        pert_adata.obs["classifier_score"] = list(regression_scores.values())
+
+        # Save adata observations for embedding annotations
+        for obs_name in adata_obs.columns:
+            if not adata_obs[obs_name].isnull().values.any():
+                pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
+                    {pert: adata_obs.loc[pert][obs_name] for pert in adata_obs.index}
+                )
+
+        return pert_adata
+
+
+# Ensure backward compatibility with DiscriminatorClassifierSpace
+def DiscriminatorClassifierSpace():
+    warnings.warn(
+        "The DiscriminatorClassifierSpace class is deprecated and will be removed in the future."
+        "Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    return MLPClassifierSpace()
+
+
+class MLPClassifierSpace(PerturbationSpace):
+    """Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.
+
+    We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
+    feature space, resulting in one embedding per cell. Consider employing the PseudoBulk or another PerturbationSpace
+    to obtain one embedding per perturbation.
+
+    See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
+    """
+
+    def compute(  # type: ignore
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str = None,
+        hidden_dim: list[int] = None,
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        batch_size: int = 256,
+        test_split_size: float = 0.2,
+        validation_split_size: float = 0.25,
+        max_epochs: int = 20,
+        val_epochs_check: int = 2,
+        patience: int = 2,
+    ) -> AnnData:
+        """Creates cell embeddings by training a MLP classifier model to distinguish between perturbations.
+
+        A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
+        the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
+        Dataloaders that take into account class imbalances are created. Next, the model is trained and tested, using the
+        GPU if available. The embeddings are obtained by passing the data through the model and extracting the values in
+        the last layer of the MLP. You will get one embedding per cell, so be aware that you might need to apply another
+        perturbation space to aggregate the embeddings per perturbation.
+
+        Args:
+            adata: AnnData object of size cells x genes
+            target_col: .obs column that stores the perturbations. Defaults to "perturbations".
+            layer_key: Layer in adata to use. Defaults to None.
+            hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
+                will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
+                Defaults to [512].
+            dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
+            batch_norm: Whether to apply batch normalization. Defaults to True.
+            batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
+            test_split_size: Fraction of data to put in the test set. Default to 0.2.
+            validation_split_size: Fraction of data to put in the validation set of the resultant train set.
+                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
+                will be used for validation. Defaults to 0.25.
+            max_epochs: Maximum number of epochs for training. Defaults to 20.
+            val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
+                Note that this affects early stopping, as the model will be stopped if the validation performance does not
+                improve for patience epochs. Defaults to 2.
+            patience: Number of validation performance checks without improvement, after which the early stopping flag
+                is activated and training is therefore stopped. Defaults to 2.
+
+        Returns:
+            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
+            The AnnData will have shape (n_cells, n_features) where n_features is the number of features in the last layer of the MLP.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> dcs = pt.tl.MLPClassifierSpace()
+            >>> cell_embeddings = dcs.compute(adata, target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.obs.columns:
+            raise ValueError(f"Layer key {layer_key} not found in adata.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if hidden_dim is None:
+            hidden_dim = [512]
+
+        # Labels are strings, one hot encoding for classification
+        n_classes = len(adata.obs[target_col].unique())
+        labels = adata.obs[target_col].values.reshape(-1, 1)
+        encoder = OneHotEncoder()
+        encoded_labels = encoder.fit_transform(labels).toarray()
+        adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
+
+        # Split the data in train, test and validation
+        X = list(range(0, adata.n_obs))
+        y = adata.obs[target_col]
+
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
+        X_train, X_val, y_train, y_val = train_test_split(
+            X_train, y_train, test_size=validation_split_size, stratify=y_train
+        )
+
+        train_dataset = PLDataset(
+            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        val_dataset = PLDataset(
+            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        test_dataset = PLDataset(
+            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )  # we don't need to pass y_test since the label selection is done inside
+
+        # Fix class unbalance (likely to happen in perturbation datasets)
+        # Usually control cells are overrepresented such that predicting control all time would give good results
+        # Cells with rare perturbations are sampled more
+        train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
+        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
+
+        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
+        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+        # Define the network
+        sizes = [adata.n_vars] + hidden_dim + [n_classes]
+        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
+
+        # Define a dataset that gathers all the data and dataloader for getting embeddings
+        total_dataset = PLDataset(
+            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)
+
+        # Save adata observations for embedding annotations in get_embeddings
+        self.adata_obs = adata.obs.reset_index(drop=True)
+
+        self.trainer = pl.Trainer(
+            min_epochs=1,
+            max_epochs=max_epochs,
+            check_val_every_n_epoch=val_epochs_check,
+            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
+            devices="auto",
+            accelerator="auto",
+        )
+
+        self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
+
+        self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
+        self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)
+
+        # Obtain cell embeddings
+        with torch.no_grad():
+            self.mlp.eval()
+            for dataset_count, batch in enumerate(self.entire_dataset):
+                emb, y = self.mlp.get_embeddings(batch)
+                emb = torch.squeeze(emb)
+                batch_adata = AnnData(X=emb.cpu().numpy())
+                batch_adata.obs["perturbations"] = y
+                if dataset_count == 0:
+                    pert_adata = batch_adata
+                else:
+                    pert_adata = anndata.concat([pert_adata, batch_adata])
+
+        # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
+        # and we can just add the annotations from the original AnnData object
+        pert_adata.obs = pert_adata.obs.reset_index(drop=True)
+        if "perturbations" in self.adata_obs.columns:
+            self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
+        pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)
+
+        # Drop the 'encoded_perturbations' colums, since this stores the one-hot encoded labels as numpy arrays,
+        # which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
+        pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
+
+        return pert_adata
+
+    def load(self, adata, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The load method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def train(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The train method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def get_embeddings(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+
+class MLP(torch.nn.Module):
+    """
+    A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
+    """
+
+    def __init__(
+        self,
+        sizes: list[int],
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+    ) -> None:
+        """
+        Args:
+            sizes: size of layers.
+            dropout: Dropout probability. Defaults to 0.0.
+            batch_norm: specifies if batch norm should be applied. Defaults to True.
+            layer_norm: specifies if layer norm should be applied, as commonly used in Transformers. Defaults to False.
+            last_layer_act: activation function of last layer. Defaults to "linear".
+        """
+        super().__init__()
+        layers = []
+        for s in range(len(sizes) - 1):
+            layers += [
+                torch.nn.Linear(sizes[s], sizes[s + 1]),
+                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
+                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
+                torch.nn.ReLU(),
+                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
+            ]
+
+        layers = [layer for layer in layers if layer is not None][:-1]
+        self.activation = last_layer_act
+        if self.activation == "linear":
+            pass
+        elif self.activation == "ReLU":
+            self.relu = torch.nn.ReLU()
+        else:
+            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
+
+        self.network = torch.nn.Sequential(*layers)
+
+        self.network.apply(init_weights)
+
+        self.sizes = sizes
+        self.batch_norm = batch_norm
+        self.layer_norm = layer_norm
+        self.last_layer_act = last_layer_act
+
+    def forward(self, x) -> torch.Tensor:
+        if self.activation == "ReLU":
+            return self.relu(self.network(x))
+        return self.network(x)
+
+    def embedding(self, x) -> torch.Tensor:
+        for layer in self.network[:-1]:
+            x = layer(x)
+        return x
+
+
+def init_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_uniform_(m.weight)
+        m.bias.data.fill_(0.01)
+
+
+class PLDataset(Dataset):
+    """
+    Dataset for perturbation classification.
+    Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
+    """
+
+    def __init__(
+        self,
+        adata: np.array,
+        target_col: str = "perturbations",
+        label_col: str = "perturbations",
+        layer_key: str = None,
+    ):
+        """
+        Args:
+            adata: AnnData object with observations and labels.
+            target_col: key with the perturbation labels numerically encoded. Defaults to 'perturbations'.
+            label_col: key with the perturbation labels. Defaults to 'perturbations'.
+            layer_key: key of the layer to be used as data, otherwise .X
+        """
+
+        if layer_key:
+            self.data = adata.layers[layer_key]
+        else:
+            self.data = adata.X
+
+        self.labels = adata.obs[target_col]
+        self.pert_labels = adata.obs[label_col]
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        """Returns a sample and corresponding perturbations applied (labels)"""
+        sample = self.data[idx].A.squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
+        num_label = self.labels.iloc[idx]
+        str_label = self.pert_labels.iloc[idx]
+
+        return sample, num_label, str_label
+
+
+class PerturbationClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        layers: list = [512],  # noqa
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+        lr=1e-4,
+        seed=42,
+    ):
+        """
+        Args:
+            model: model to be trained
+            batch_size: batch size
+            layers: list of layers of the MLP
+            dropout: dropout probability
+            batch_norm: whether to apply batch norm
+            layer_norm: whether to apply layer norm
+            last_layer_act: activation function of last layer
+            lr: learning rate
+            seed: random seed
+        """
+        super().__init__()
+        self.batch_size = batch_size
+        self.save_hyperparameters()
+        if model:
+            self.net = model
+        else:
+            self._create_model()
+
+    def _create_model(self):
+        self.net = MLP(
+            sizes=self.hparams.layers,
+            dropout=self.hparams.dropout,
+            batch_norm=self.hparams.batch_norm,
+            layer_norm=self.hparams.layer_norm,
+            last_layer_act=self.hparams.last_layer_act,
+        )
+
+    def forward(self, x):
+        x = self.net(x)
+        return x
+
+    def configure_optimizers(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
+
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("test_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def embedding(self, x):
+        """
+        Inputs:
+            x: Input features of shape [Batch, SeqLen, 1]
+        """
+        x = self.net.embedding(x)
+        return x
+
+    def get_embeddings(self, batch):
+        x, _, y = batch
+        x = x.to(torch.float32)
+
+        embedding = self.embedding(x)
+        return embedding, y
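For orientation, the sketch below ties the two new spaces from this file together, following the docstring examples above. It is not part of the diff: `pt.tl.PseudobulkSpace` is assumed to be available for the per-cell to per-perturbation aggregation that the `MLPClassifierSpace` docstring recommends, and the deprecated `DiscriminatorClassifierSpace` shim is assumed to still be re-exported under `pt.tl`.

```python
import pertpy as pt

adata = pt.dt.norman_2019()

# LRClassifierSpace: one row per perturbation (logistic regression coefficients),
# plus a classifier_score column with the held-out accuracy of each per-perturbation model.
lr_space = pt.tl.LRClassifierSpace()
pert_embeddings = lr_space.compute(adata, embedding_key="X_pca", target_col="perturbation_name")

# MLPClassifierSpace: one embedding per cell (penultimate MLP layer), so aggregate
# afterwards to obtain one embedding per perturbation, e.g. via a pseudobulk space.
mlp_space = pt.tl.MLPClassifierSpace()
cell_embeddings = mlp_space.compute(adata, target_col="perturbation_name", max_epochs=20)
pb_space = pt.tl.PseudobulkSpace()  # assumed export, as in other pertpy releases
pert_embeddings_mlp = pb_space.compute(cell_embeddings, target_col="perturbations")

# The 0.6.x entry point still resolves but emits a DeprecationWarning and
# returns an MLPClassifierSpace instance (see the backward-compatibility shim above).
legacy_space = pt.tl.DiscriminatorClassifierSpace()
```

The practical difference between the two routes: the logistic regression space is cheap and directly per-perturbation, while the MLP space captures non-linear structure at per-cell resolution and therefore needs the extra aggregation step.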