pertpy 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/data/_dataloader.py +2 -2
- pertpy/data/_datasets.py +62 -62
- pertpy/metadata/_cell_line.py +9 -3
- pertpy/metadata/_drug.py +4 -2
- pertpy/preprocessing/_guide_rna.py +17 -10
- pertpy/preprocessing/_guide_rna_mixture.py +9 -3
- pertpy/tools/__init__.py +12 -2
- pertpy/tools/_augur.py +37 -14
- pertpy/tools/_coda/_sccoda.py +68 -101
- pertpy/tools/_coda/_tasccoda.py +103 -85
- pertpy/tools/_mixscape.py +48 -39
- pertpy/tools/_perturbation_space/_comparison.py +3 -3
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +261 -353
- pertpy/tools/_perturbation_space/_perturbation_space.py +22 -14
- pertpy/tools/_perturbation_space/_simple.py +12 -6
- pertpy/tools/_scgen/_scgenvae.py +2 -1
- pertpy/tools/core.py +18 -0
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/METADATA +14 -2
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/RECORD +22 -21
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/WHEEL +0 -0
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,18 +1,21 @@
 from __future__ import annotations

-import anndata
+from typing import Any
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
 import numpy as np
+import optax
+import pandas as pd
 import scipy
-import torch
 from anndata import AnnData
 from fast_array_utils.conv import to_dense
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import EarlyStopping
+from flax.training import train_state
+from jax import random
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import OneHotEncoder
-from torch import optim
-from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

 from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace

@@ -74,13 +77,11 @@ class LRClassifierSpace(PerturbationSpace):

         regression_labels = adata.obs[target_col]

-        # Save adata observations for embedding annotations in get_embeddings
         adata_obs = adata.obs.reset_index(drop=True)
         adata_obs = adata_obs.groupby(target_col).agg(
             lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
         )

-        # Fit a logistic regression model for each perturbation
         regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
         regression_embeddings = {}
         regression_scores = {}
@@ -95,12 +96,10 @@ class LRClassifierSpace(PerturbationSpace):
             regression_embeddings[perturbation] = regression_model.coef_
             regression_scores[perturbation] = regression_model.score(X_test, y_test)

-        # Save the regression embeddings and scores in an AnnData object
         pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
         pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
         pert_adata.obs["classifier_score"] = list(regression_scores.values())

-        # Save adata observations for embedding annotations
         for obs_name in adata_obs.columns:
             if not adata_obs[obs_name].isnull().values.any():
                 pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
@@ -110,6 +109,174 @@ class LRClassifierSpace(PerturbationSpace):
         return pert_adata


+class MLP(nn.Module):
+    """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
+
+    sizes: list[int]
+    dropout: float = 0.0
+    batch_norm: bool = True
+    layer_norm: bool = False
+    last_layer_act: str = "linear"
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
+        for i in range(len(self.sizes) - 1):
+            x = nn.Dense(self.sizes[i + 1])(x)
+
+            if i < len(self.sizes) - 2:
+                if self.batch_norm:
+                    x = nn.BatchNorm(use_running_average=not training)(x)
+                elif self.layer_norm:
+                    x = nn.LayerNorm()(x)
+
+                x = nn.relu(x)
+
+                if self.dropout > 0 and training:
+                    x = nn.Dropout(rate=self.dropout, deterministic=not training)(x)
+
+        if self.last_layer_act == "ReLU":
+            x = nn.relu(x)
+
+        return x
+
+    @nn.compact
+    def embedding(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
+        for i in range(len(self.sizes) - 2):
+            x = nn.Dense(self.sizes[i + 1])(x)
+
+            if self.batch_norm:
+                x = nn.BatchNorm(use_running_average=True)(x)
+            elif self.layer_norm:
+                x = nn.LayerNorm()(x)
+
+            x = nn.relu(x)
+
+            if self.dropout > 0 and training:
+                x = nn.Dropout(rate=self.dropout, deterministic=True)(x)
+
+        return x
+
+
+class TrainState(train_state.TrainState):
+    batch_stats: Any
+
+
+def create_train_state(rng: jnp.ndarray, model: nn.Module, input_shape: tuple[int, ...], lr: float) -> TrainState:
+    dummy_input = jnp.ones((1,) + input_shape)
+    rng, init_rng, dropout_rng = random.split(rng, 3)
+    variables = model.init({"params": init_rng, "dropout": dropout_rng}, dummy_input, training=True)
+    params = variables["params"]
+    batch_stats = variables.get("batch_stats", {})
+
+    tx = optax.adamw(learning_rate=lr, weight_decay=0.1)
+
+    return TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)
+
+
+@jax.jit
+def train_step(state: TrainState, batch: tuple[jnp.ndarray, jnp.ndarray], rng: jnp.ndarray) -> tuple[TrainState, float]:
+    def loss_fn(params):
+        x, y = batch
+        variables = {"params": params, "batch_stats": state.batch_stats}
+        logits, new_batch_stats = state.apply_fn(
+            variables, x, training=True, mutable=["batch_stats"], rngs={"dropout": rng}
+        )
+
+        y_indices = jnp.argmax(y, axis=1)
+        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y_indices).mean()
+        return loss, new_batch_stats
+
+    (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+    state = state.apply_gradients(grads=grads)
+    state = state.replace(batch_stats=new_batch_stats["batch_stats"])
+
+    return state, loss
+
+
+@jax.jit
+def val_step(state: TrainState, batch: tuple[jnp.ndarray, jnp.ndarray]) -> float:
+    x, y = batch
+    variables = {"params": state.params, "batch_stats": state.batch_stats}
+    logits = state.apply_fn(variables, x, training=False)
+
+    y_indices = jnp.argmax(y, axis=1)
+    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y_indices).mean()
+    return loss
+
+
+@jax.jit
+def get_embeddings(state: TrainState, x: jnp.ndarray) -> jnp.ndarray:
+    variables = {"params": state.params, "batch_stats": state.batch_stats}
+    return state.apply_fn(variables, x, training=False, method="embedding")
+
+
+class JAXDataset:
+    """Dataset for perturbation classification.
+
+    Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
+    """
+
+    def __init__(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        label_col: str = "perturbations",
+        layer_key: str = None,
+    ):
+        """JAX Dataset for perturbation classification.
+
+        Args:
+            adata: AnnData object with observations and labels.
+            target_col: key with the perturbation labels numerically encoded.
+            label_col: key with the perturbation labels.
+            layer_key: key of the layer to be used as data, otherwise .X.
+        """
+        if layer_key:
+            self.data = adata.layers[layer_key]
+        else:
+            self.data = adata.X
+
+        if target_col in adata.obs.columns:
+            self.labels = adata.obs[target_col].values
+        elif target_col in adata.obsm:
+            self.labels = adata.obsm[target_col]
+        else:
+            raise ValueError(f"Target column {target_col} not found in obs or obsm")
+
+        self.pert_labels = adata.obs[label_col].values
+
+        if scipy.sparse.issparse(self.data):
+            self.data = to_dense(self.data)
+
+        self.data = jnp.array(self.data, dtype=jnp.float32)
+        self.labels = jnp.array(self.labels, dtype=jnp.float32)
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def get_batch(self, indices: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, list]:
+        """Returns a batch of samples and corresponding perturbations applied (labels)."""
+        batch_data = self.data[indices]
+        batch_labels = self.labels[indices]
+        batch_pert_labels = [self.pert_labels[i] for i in indices]
+        return batch_data, batch_labels, batch_pert_labels
+
+
+def create_batched_indices(
+    dataset_size: int, rng: jnp.ndarray, batch_size: int, n_batches: int, weights: jnp.ndarray | None = None
+) -> list:
+    """Create batched indices for training, optionally with weighted sampling."""
+    batches = []
+    for _ in range(n_batches):
+        rng, batch_rng = random.split(rng)
+        if weights is not None:
+            batch_indices = random.choice(batch_rng, dataset_size, shape=(batch_size,), p=weights)
+        else:
+            batch_indices = random.choice(batch_rng, dataset_size, shape=(batch_size,), replace=False)
+        batches.append(batch_indices)
+    return batches
+
+
 class MLPClassifierSpace(PerturbationSpace):
     """Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.

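The additions above follow the standard Flax train-state pattern: a custom TrainState that carries batch_stats, a jitted train step that threads a dropout RNG and marks batch_stats as mutable, and AdamW from optax. The following standalone sketch distills that pattern; TinyClassifier, its sizes, and all hyperparameters are illustrative only and not part of pertpy.

from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from jax import random


class TinyClassifier(nn.Module):
    n_classes: int = 3

    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Dense(16)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.1, deterministic=not training)(x)
        return nn.Dense(self.n_classes)(x)


class TrainState(train_state.TrainState):
    batch_stats: Any  # carries BatchNorm running statistics alongside params


model = TinyClassifier()
rng = random.PRNGKey(0)
variables = model.init({"params": rng, "dropout": rng}, jnp.ones((1, 8)), training=True)
state = TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adamw(1e-4, weight_decay=0.1),
    batch_stats=variables["batch_stats"],
)


@jax.jit
def step(state, x, y, dropout_rng):
    # One gradient update: batch_stats are mutable, dropout gets its own RNG.
    def loss_fn(params):
        logits, updates = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            x, training=True, mutable=["batch_stats"], rngs={"dropout": dropout_rng},
        )
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean(), updates

    (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return state.apply_gradients(grads=grads).replace(batch_stats=updates["batch_stats"]), loss


x = random.normal(random.PRNGKey(1), (32, 8))
y = random.randint(random.PRNGKey(2), (32,), 0, 3)
state, loss = step(state, x, y, random.PRNGKey(3))
print(float(loss))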
@@ -120,7 +287,7 @@ class MLPClassifierSpace(PerturbationSpace):
     See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
     """

-    def compute(
+    def compute(
         self,
         adata: AnnData,
         target_col: str = "perturbations",
@@ -128,12 +295,14 @@
         hidden_dim: list[int] = None,
         dropout: float = 0.0,
         batch_norm: bool = True,
-        batch_size: int =
+        batch_size: int = 128,
         test_split_size: float = 0.2,
         validation_split_size: float = 0.25,
         max_epochs: int = 20,
         val_epochs_check: int = 2,
         patience: int = 2,
+        lr: float = 1e-4,
+        seed: int = 42,
     ) -> AnnData:
         """Creates cell embeddings by training a MLP classifier model to distinguish between perturbations.

@@ -148,21 +317,21 @@ class MLPClassifierSpace(PerturbationSpace):
             adata: AnnData object of size cells x genes
             target_col: .obs column that stores the perturbations.
             layer_key: Layer in adata to use.
-            hidden_dim: List of number of neurons in each hidden layers of the neural network.
-                will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
+            hidden_dim: List of number of neurons in each hidden layers of the neural network.
+                For instance, [512, 256] will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
             dropout: Amount of dropout applied, constant for all layers.
             batch_norm: Whether to apply batch normalization.
             batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
             test_split_size: Fraction of data to put in the test set. Default to 0.2.
             validation_split_size: Fraction of data to put in the validation set of the resultant train set.
-                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
-                will be used for validation.
+                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data will be used for validation.
             max_epochs: Maximum number of epochs for training.
             val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
-                Note that this affects early stopping, as the model will be stopped if the validation performance does not
-                improve for patience epochs.
+                Note that this affects early stopping, as the model will be stopped if the validation performance does not improve for patience epochs.
             patience: Number of validation performance checks without improvement, after which the early stopping flag
                 is activated and training is therefore stopped.
+            lr: Learning rate for training.
+            seed: Random seed for reproducibility.

         Returns:
             AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
@@ -188,9 +357,9 @@ class MLPClassifierSpace(PerturbationSpace):
         labels = adata.obs[target_col].values.reshape(-1, 1)
         encoder = OneHotEncoder()
         encoded_labels = encoder.fit_transform(labels).toarray()
+        adata = adata.copy()
         adata.obsm["encoded_perturbations"] = encoded_labels.astype(np.float32)

-        # Split the data in train, test and validation
         X = list(range(adata.n_obs))
         y = adata.obs[target_col]

@@ -199,368 +368,107 @@ class MLPClassifierSpace(PerturbationSpace):
             X_train, y_train, test_size=validation_split_size, stratify=y_train
         )

-        train_dataset = PLDataset(
+        train_dataset = JAXDataset(
             adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
         )
-        val_dataset = PLDataset(
+        val_dataset = JAXDataset(
             adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
         )
-        test_dataset = PLDataset(
+        test_dataset = JAXDataset(
             adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-        )  # we don't need to pass y_test since the label selection is done inside
-
-        # Fix class unbalance (likely to happen in perturbation datasets)
-        # Usually control cells are overrepresented such that predicting control all time would give good results
-        # Cells with rare perturbations are sampled more
-        train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
-        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
-
-        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
-        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
-        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
-
-        # Define the network
-        sizes = [adata.n_vars] + hidden_dim + [n_classes]
-        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
-
-        # Define a dataset that gathers all the data and dataloader for getting embeddings
-        total_dataset = PLDataset(
-            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-        )
-        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)
-
-        # Save adata observations for embedding annotations in get_embeddings
-        self.adata_obs = adata.obs.reset_index(drop=True)
-
-        self.trainer = Trainer(
-            min_epochs=1,
-            max_epochs=max_epochs,
-            check_val_every_n_epoch=val_epochs_check,
-            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
-            devices="auto",
-            accelerator="auto",
-        )
-
-        self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
-
-        self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
-        self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)
-
-        # Obtain cell embeddings
-        with torch.no_grad():
-            self.mlp.eval()
-            for dataset_count, batch in enumerate(self.entire_dataset):
-                emb, y = self.mlp.get_embeddings(batch)
-                emb = torch.squeeze(emb)
-                batch_adata = AnnData(X=emb.cpu().numpy())
-                batch_adata.obs["perturbations"] = y
-                if dataset_count == 0:
-                    pert_adata = batch_adata
-                else:
-                    pert_adata = batch_adata if dataset_count == 0 else anndata.concat([pert_adata, batch_adata])
-
-        # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
-        # and we can just add the annotations from the original AnnData object
-        pert_adata.obs = pert_adata.obs.reset_index(drop=True)
-        if "perturbations" in self.adata_obs.columns:
-            self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
-        obs_subset = self.adata_obs.iloc[: len(pert_adata.obs)].copy()
-        for col in obs_subset.columns:
-            if col not in ["perturbations", "encoded_perturbations"]:
-                pert_adata.obs[col] = obs_subset[col].values
-
-        return pert_adata
-
-    def load(self, adata, **kwargs):
-        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
-        raise DeprecationWarning(
-            "The load method is deprecated and will be removed in the future. Please use the compute method instead."
-        )
-
-    def train(self, **kwargs):
-        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
-        raise DeprecationWarning(
-            "The train method is deprecated and will be removed in the future. Please use the compute method instead."
         )
-
-
-        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
-        raise DeprecationWarning(
-            "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
+        total_dataset = JAXDataset(
+            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
         )

+        rng = random.PRNGKey(seed)
+        rng, init_rng, train_rng = random.split(rng, 3)

-
-
-
-    def __init__(
-        self,
-        sizes: list[int],
-        dropout: float = 0.0,
-        batch_norm: bool = True,
-        layer_norm: bool = False,
-        last_layer_act: str = "linear",
-    ) -> None:
-        """Multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
-
-        Args:
-            sizes: size of layers.
-            dropout: Dropout probability.
-            batch_norm: specifies if batch norm should be applied.
-            layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
-            last_layer_act: activation function of last layer.
-        """
-        super().__init__()
-        layers = []
-        for s in range(len(sizes) - 1):
-            layers += [
-                torch.nn.Linear(sizes[s], sizes[s + 1]),
-                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
-                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
-                torch.nn.ReLU(),
-                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
-            ]
-
-        layers = [layer for layer in layers if layer is not None][:-1]
-        self.activation = last_layer_act
-        if self.activation == "linear":
-            pass
-        elif self.activation == "ReLU":
-            self.relu = torch.nn.ReLU()
-        else:
-            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
-
-        self.network = torch.nn.Sequential(*layers)
-
-        self.network.apply(init_weights)
-
-        self.sizes = sizes
-        self.batch_norm = batch_norm
-        self.layer_norm = layer_norm
-        self.last_layer_act = last_layer_act
-
-    def forward(self, x) -> torch.Tensor:
-        if self.activation == "ReLU":
-            return self.relu(self.network(x))
-        return self.network(x)
-
-    def embedding(self, x) -> torch.Tensor:
-        for layer in self.network[:-1]:
-            x = layer(x)
-        return x
-
-
-def init_weights(m):
-    if isinstance(m, torch.nn.Linear):
-        torch.nn.init.kaiming_uniform_(m.weight)
-        m.bias.data.fill_(0.01)
-
-
-class PLDataset(Dataset):
-    """Dataset for perturbation classification.
-
-    Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
-    """
-
-    def __init__(
-        self,
-        adata: np.array,
-        target_col: str = "perturbations",
-        label_col: str = "perturbations",
-        layer_key: str = None,
-    ):
-        """PyTorch lightning Dataset for perturbation classification.
-
-        Args:
-            adata: AnnData object with observations and labels.
-            target_col: key with the perturbation labels numerically encoded.
-            label_col: key with the perturbation labels.
-            layer_key: key of the layer to be used as data, otherwise .X.
-        """
-        if layer_key:
-            self.data = adata.layers[layer_key]
-        else:
-            self.data = adata.X
-
-        if target_col in adata.obs.columns:
-            self.labels = adata.obs[target_col]
-        elif target_col in adata.obsm:
-            self.labels = adata.obsm[target_col]
-        else:
-            raise ValueError(f"Target column {target_col} not found in obs or obsm")
-
-        self.pert_labels = adata.obs[label_col]
-
-    def __len__(self):
-        return self.data.shape[0]
-
-    def __getitem__(self, idx):
-        """Returns a sample and corresponding perturbations applied (labels)."""
-        sample = to_dense(self.data[idx]).squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
-        num_label = self.labels.iloc[idx] if hasattr(self.labels, "iloc") else self.labels[idx]
-        str_label = self.pert_labels.iloc[idx]
-
-        return sample, num_label, str_label
+        sizes = [adata.n_vars] + hidden_dim + [n_classes]
+        model = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)

+        state = create_train_state(init_rng, model, (adata.n_vars,), lr)

-class PerturbationClassifier(LightningModule):
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        batch_size: int,
-        layers: list = [512], # noqa
-        dropout: float = 0.0,
-        batch_norm: bool = True,
-        layer_norm: bool = False,
-        last_layer_act: str = "linear",
-        lr=1e-4,
-        seed=42,
-    ):
-        """Perturbation Classifier.
+        # Create weighted sampling for class imbalance
+        weights = 1.0 / (1.0 + jnp.sum(train_dataset.labels, axis=1))
+        weights = weights / jnp.sum(weights)

-
-
-            batch_size
-            layers: list of layers of the MLP
-            dropout: dropout probability
-            batch_norm: whether to apply batch norm
-            layer_norm: whether to apply layer norm
-            last_layer_act: activation function of last layer
-            lr: learning rate
-            seed: random seed.
-        """
-        super().__init__()
-        self.batch_size = batch_size
-        self.save_hyperparameters()
-        if model:
-            self.net = model
-        else:
-            self._create_model()
-
-    def _create_model(self):
-        self.net = MLP(
-            sizes=self.hparams.layers,
-            dropout=self.hparams.dropout,
-            batch_norm=self.hparams.batch_norm,
-            layer_norm=self.hparams.layer_norm,
-            last_layer_act=self.hparams.last_layer_act,
+        n_batches_per_epoch = len(train_dataset) // batch_size
+        train_batches = create_batched_indices(
+            len(train_dataset), train_rng, batch_size, max_epochs * n_batches_per_epoch, weights
         )

-
-
+        best_val_loss = float("inf")
+        patience_counter = 0

-
-
+        for epoch in range(max_epochs):
+            epoch_train_loss = 0

-
-
-
-        x = self.net(x)
-        return x
+            epoch_start = epoch * n_batches_per_epoch
+            epoch_end = (epoch + 1) * n_batches_per_epoch
+            epoch_batches = train_batches[epoch_start:epoch_end]

-
-
+            for _n_train_batches, batch_indices in enumerate(epoch_batches, 1):
+                rng, step_rng = random.split(rng)
+                batch_data, batch_labels, *_ = train_dataset.get_batch(batch_indices)
+                state, loss = train_step(state, (batch_data, batch_labels), step_rng)
+                epoch_train_loss += loss

-
-
-
-
-
-
-
-        """Perform a training step.
-
-        Args:
-            batch: Tuple of (input, target, metadata)
-            batch_idx: Index of the current batch
-
-        Returns:
-            Loss value
-        """
-        x, y, _ = batch
-        x = x.to(torch.float32)
-
-        y_hat = self.forward(x)
-
-        y = torch.argmax(y, dim=1)
-        y_hat = y_hat.squeeze()
+            if (epoch + 1) % val_epochs_check == 0:
+                val_losses = []
+                for i in range(0, len(val_dataset), batch_size):
+                    val_indices = jnp.arange(i, min(i + batch_size, len(val_dataset)))
+                    val_batch_data, val_batch_labels, _ = val_dataset.get_batch(val_indices)
+                    val_loss = val_step(state, (val_batch_data, val_batch_labels))
+                    val_losses.append(val_loss)

-
-        self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)
+                avg_val_loss = jnp.mean(jnp.array(val_losses))

-
-
-
-
-
-        Args:
-            batch: Tuple of (input, target, metadata)
-            batch_idx: Index of the current batch
-
-        Returns:
-            Loss value
-        """
-        x, y, _ = batch
-        x = x.to(torch.float32)
-
-        y_hat = self.forward(x)
-
-        y = torch.argmax(y, dim=1)
-        y_hat = y_hat.squeeze()
-
-        loss = torch.nn.functional.cross_entropy(y_hat, y)
-        self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)
-
-        return loss
-
-    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
-        """Perform a test step.
-
-        Args:
-            batch: Tuple of (input, target, metadata)
-            batch_idx: Index of the current batch
-
-        Returns:
-            Loss value
-        """
-        x, y, _ = batch
-        x = x.to(torch.float32)
+                if avg_val_loss < best_val_loss:
+                    best_val_loss = avg_val_loss
+                    patience_counter = 0
+                else:
+                    patience_counter += 1

-
+                if patience_counter >= patience:
+                    break

-
-
+        # Test evaluation
+        test_losses = []
+        for i in range(0, len(test_dataset), batch_size):
+            test_indices = jnp.arange(i, min(i + batch_size, len(test_dataset)))
+            test_batch_data, test_batch_labels, _ = test_dataset.get_batch(test_indices)
+            test_loss = val_step(state, (test_batch_data, test_batch_labels))
+            test_losses.append(test_loss)

-
-
+        # Extract embeddings
+        embeddings_list = []
+        labels_list = []

-
+        for i in range(0, len(total_dataset), batch_size * 2):
+            indices = jnp.arange(i, min(i + batch_size * 2, len(total_dataset)))
+            batch_data, _, batch_pert_labels = total_dataset.get_batch(indices)
+            batch_embeddings = get_embeddings(state, batch_data)

-
-
+            embeddings_list.append(batch_embeddings)
+            labels_list.extend(batch_pert_labels)

-
-            x: Input tensor of shape [Batch, SeqLen, 1]
+        all_embeddings = jnp.concatenate(embeddings_list, axis=0)

-
-
-        """
-        x = self.net.embedding(x)
-        return x
+        pert_adata = AnnData(X=np.array(all_embeddings))
+        pert_adata.obs["perturbations"] = labels_list

-
-
-
-        """Extract embeddings from a batch.
+        adata_obs = adata.obs.reset_index(drop=True)
+        if "perturbations" in adata_obs.columns:
+            adata_obs = adata_obs.drop("perturbations", axis=1)

-
-
+        obs_subset = adata_obs.iloc[: len(pert_adata.obs)].copy()
+        cols_to_add = [col for col in obs_subset.columns if col not in ["perturbations", "encoded_perturbations"]]
+        new_cols_data = {col: obs_subset[col].values for col in cols_to_add}

-
-
-
-
-        x = x.to(torch.float32)
+        if new_cols_data:
+            pert_adata.obs = pd.concat(
+                [pert_adata.obs, pd.DataFrame(new_cols_data, index=pert_adata.obs.index)], axis=1
+            )

-
-        return embedding, y
+        return pert_adata
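With this change the classifier space trains entirely in JAX; the public entry point is unchanged apart from the new lr and seed arguments to compute(). A usage sketch, assuming the usual pt.tl export of MLPClassifierSpace in pertpy 1.0.3 (the AnnData, column values, and hyperparameters below are synthetic placeholders):

import numpy as np
import pandas as pd
from anndata import AnnData
import pertpy as pt

# Toy dataset: 300 cells x 50 genes with a categorical perturbation label.
gen = np.random.default_rng(0)
adata = AnnData(
    X=gen.normal(size=(300, 50)).astype(np.float32),
    obs=pd.DataFrame({"perturbations": gen.choice(["control", "KO_A", "KO_B"], size=300)}),
)

ps = pt.tl.MLPClassifierSpace()
pert_emb = ps.compute(
    adata,
    target_col="perturbations",
    hidden_dim=[64, 32],
    batch_size=32,
    max_epochs=4,
    val_epochs_check=2,
    lr=1e-3,  # new in 1.0.3
    seed=0,   # new in 1.0.3
)
print(pert_emb)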