pertpy 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,18 +1,21 @@
  from __future__ import annotations

- import anndata
+ from typing import Any
+
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
  import numpy as np
+ import optax
+ import pandas as pd
  import scipy
- import torch
  from anndata import AnnData
  from fast_array_utils.conv import to_dense
- from pytorch_lightning import LightningModule, Trainer
- from pytorch_lightning.callbacks import EarlyStopping
+ from flax.training import train_state
+ from jax import random
  from sklearn.linear_model import LogisticRegression
  from sklearn.model_selection import train_test_split
  from sklearn.preprocessing import OneHotEncoder
- from torch import optim
- from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

  from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace

@@ -74,13 +77,11 @@ class LRClassifierSpace(PerturbationSpace):

          regression_labels = adata.obs[target_col]

-         # Save adata observations for embedding annotations in get_embeddings
          adata_obs = adata.obs.reset_index(drop=True)
          adata_obs = adata_obs.groupby(target_col).agg(
              lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
          )

-         # Fit a logistic regression model for each perturbation
          regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
          regression_embeddings = {}
          regression_scores = {}
@@ -95,12 +96,10 @@ class LRClassifierSpace(PerturbationSpace):
              regression_embeddings[perturbation] = regression_model.coef_
              regression_scores[perturbation] = regression_model.score(X_test, y_test)

-         # Save the regression embeddings and scores in an AnnData object
          pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
          pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
          pert_adata.obs["classifier_score"] = list(regression_scores.values())

-         # Save adata observations for embedding annotations
          for obs_name in adata_obs.columns:
              if not adata_obs[obs_name].isnull().values.any():
                  pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
@@ -110,6 +109,174 @@ class LRClassifierSpace(PerturbationSpace):
          return pert_adata


+ class MLP(nn.Module):
+     """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
+
+     sizes: list[int]
+     dropout: float = 0.0
+     batch_norm: bool = True
+     layer_norm: bool = False
+     last_layer_act: str = "linear"
+
+     @nn.compact
+     def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
+         for i in range(len(self.sizes) - 1):
+             x = nn.Dense(self.sizes[i + 1])(x)
+
+             if i < len(self.sizes) - 2:
+                 if self.batch_norm:
+                     x = nn.BatchNorm(use_running_average=not training)(x)
+                 elif self.layer_norm:
+                     x = nn.LayerNorm()(x)
+
+                 x = nn.relu(x)
+
+                 if self.dropout > 0 and training:
+                     x = nn.Dropout(rate=self.dropout, deterministic=not training)(x)
+
+         if self.last_layer_act == "ReLU":
+             x = nn.relu(x)
+
+         return x
+
+     @nn.compact
+     def embedding(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
+         for i in range(len(self.sizes) - 2):
+             x = nn.Dense(self.sizes[i + 1])(x)
+
+             if self.batch_norm:
+                 x = nn.BatchNorm(use_running_average=True)(x)
+             elif self.layer_norm:
+                 x = nn.LayerNorm()(x)
+
+             x = nn.relu(x)
+
+             if self.dropout > 0 and training:
+                 x = nn.Dropout(rate=self.dropout, deterministic=True)(x)
+
+         return x
+
+
+ class TrainState(train_state.TrainState):
+     batch_stats: Any
+
+
+ def create_train_state(rng: jnp.ndarray, model: nn.Module, input_shape: tuple[int, ...], lr: float) -> TrainState:
+     dummy_input = jnp.ones((1,) + input_shape)
+     rng, init_rng, dropout_rng = random.split(rng, 3)
+     variables = model.init({"params": init_rng, "dropout": dropout_rng}, dummy_input, training=True)
+     params = variables["params"]
+     batch_stats = variables.get("batch_stats", {})
+
+     tx = optax.adamw(learning_rate=lr, weight_decay=0.1)
+
+     return TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)
+
+
+ @jax.jit
+ def train_step(state: TrainState, batch: tuple[jnp.ndarray, jnp.ndarray], rng: jnp.ndarray) -> tuple[TrainState, float]:
+     def loss_fn(params):
+         x, y = batch
+         variables = {"params": params, "batch_stats": state.batch_stats}
+         logits, new_batch_stats = state.apply_fn(
+             variables, x, training=True, mutable=["batch_stats"], rngs={"dropout": rng}
+         )
+
+         y_indices = jnp.argmax(y, axis=1)
+         loss = optax.softmax_cross_entropy_with_integer_labels(logits, y_indices).mean()
+         return loss, new_batch_stats
+
+     (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+     state = state.apply_gradients(grads=grads)
+     state = state.replace(batch_stats=new_batch_stats["batch_stats"])
+
+     return state, loss
+
+
+ @jax.jit
+ def val_step(state: TrainState, batch: tuple[jnp.ndarray, jnp.ndarray]) -> float:
+     x, y = batch
+     variables = {"params": state.params, "batch_stats": state.batch_stats}
+     logits = state.apply_fn(variables, x, training=False)
+
+     y_indices = jnp.argmax(y, axis=1)
+     loss = optax.softmax_cross_entropy_with_integer_labels(logits, y_indices).mean()
+     return loss
+
+
+ @jax.jit
+ def get_embeddings(state: TrainState, x: jnp.ndarray) -> jnp.ndarray:
+     variables = {"params": state.params, "batch_stats": state.batch_stats}
+     return state.apply_fn(variables, x, training=False, method="embedding")
+
+
+ class JAXDataset:
+     """Dataset for perturbation classification.
+
+     Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
+     """
+
+     def __init__(
+         self,
+         adata: AnnData,
+         target_col: str = "perturbations",
+         label_col: str = "perturbations",
+         layer_key: str = None,
+     ):
+         """JAX Dataset for perturbation classification.
+
+         Args:
+             adata: AnnData object with observations and labels.
+             target_col: key with the perturbation labels numerically encoded.
+             label_col: key with the perturbation labels.
+             layer_key: key of the layer to be used as data, otherwise .X.
+         """
+         if layer_key:
+             self.data = adata.layers[layer_key]
+         else:
+             self.data = adata.X
+
+         if target_col in adata.obs.columns:
+             self.labels = adata.obs[target_col].values
+         elif target_col in adata.obsm:
+             self.labels = adata.obsm[target_col]
+         else:
+             raise ValueError(f"Target column {target_col} not found in obs or obsm")
+
+         self.pert_labels = adata.obs[label_col].values
+
+         if scipy.sparse.issparse(self.data):
+             self.data = to_dense(self.data)
+
+         self.data = jnp.array(self.data, dtype=jnp.float32)
+         self.labels = jnp.array(self.labels, dtype=jnp.float32)
+
+     def __len__(self):
+         return self.data.shape[0]
+
+     def get_batch(self, indices: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, list]:
+         """Returns a batch of samples and corresponding perturbations applied (labels)."""
+         batch_data = self.data[indices]
+         batch_labels = self.labels[indices]
+         batch_pert_labels = [self.pert_labels[i] for i in indices]
+         return batch_data, batch_labels, batch_pert_labels
+
+
+ def create_batched_indices(
+     dataset_size: int, rng: jnp.ndarray, batch_size: int, n_batches: int, weights: jnp.ndarray | None = None
+ ) -> list:
+     """Create batched indices for training, optionally with weighted sampling."""
+     batches = []
+     for _ in range(n_batches):
+         rng, batch_rng = random.split(rng)
+         if weights is not None:
+             batch_indices = random.choice(batch_rng, dataset_size, shape=(batch_size,), p=weights)
+         else:
+             batch_indices = random.choice(batch_rng, dataset_size, shape=(batch_size,), replace=False)
+         batches.append(batch_indices)
+     return batches
+
+
  class MLPClassifierSpace(PerturbationSpace):
      """Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.

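The helpers added in this hunk follow the usual Flax/optax training pattern: a `TrainState` subclass that carries BatchNorm `batch_stats` next to the parameters, and jitted step functions that return an updated state. The sketch below is a minimal, self-contained illustration of that pattern with a made-up two-layer model and random data; it is not part of pertpy and assumes only standard flax/optax APIs.

from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from jax import random


class TinyClassifier(nn.Module):
    """Hypothetical two-layer classifier, used only to illustrate the pattern."""

    hidden: int = 32
    n_classes: int = 3

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
        x = nn.Dense(self.hidden)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        return nn.Dense(self.n_classes)(x)


class TrainState(train_state.TrainState):
    batch_stats: Any


model = TinyClassifier()
variables = model.init(random.PRNGKey(0), jnp.ones((1, 8)), training=False)
state = TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adamw(learning_rate=1e-3),
    batch_stats=variables["batch_stats"],
)


@jax.jit
def step(state: TrainState, x: jnp.ndarray, y: jnp.ndarray):
    def loss_fn(params):
        # Run with mutable batch_stats so the BatchNorm running statistics are updated.
        logits, updates = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats}, x, training=True, mutable=["batch_stats"]
        )
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean(), updates

    (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return state.apply_gradients(grads=grads).replace(batch_stats=updates["batch_stats"]), loss


x = random.normal(random.PRNGKey(1), (16, 8))
y = jnp.zeros((16,), dtype=jnp.int32)
state, loss = step(state, x, y)

Jitting the step keeps the whole update on device; the state, a pytree, is threaded through functionally rather than mutated in place, which is also how the training loop inside compute below works.
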
@@ -120,7 +287,7 @@ class MLPClassifierSpace(PerturbationSpace):
      See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
      """

-     def compute( # type: ignore
+     def compute(
          self,
          adata: AnnData,
          target_col: str = "perturbations",
@@ -128,12 +295,14 @@ class MLPClassifierSpace(PerturbationSpace):
          hidden_dim: list[int] = None,
          dropout: float = 0.0,
          batch_norm: bool = True,
-         batch_size: int = 256,
+         batch_size: int = 128,
          test_split_size: float = 0.2,
          validation_split_size: float = 0.25,
          max_epochs: int = 20,
          val_epochs_check: int = 2,
          patience: int = 2,
+         lr: float = 1e-4,
+         seed: int = 42,
      ) -> AnnData:
          """Creates cell embeddings by training a MLP classifier model to distinguish between perturbations.

@@ -148,21 +317,21 @@ class MLPClassifierSpace(PerturbationSpace):
              adata: AnnData object of size cells x genes
              target_col: .obs column that stores the perturbations.
              layer_key: Layer in adata to use.
-             hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
-                 will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
+             hidden_dim: List of number of neurons in each hidden layers of the neural network.
+                 For instance, [512, 256] will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
              dropout: Amount of dropout applied, constant for all layers.
              batch_norm: Whether to apply batch normalization.
              batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
              test_split_size: Fraction of data to put in the test set. Default to 0.2.
              validation_split_size: Fraction of data to put in the validation set of the resultant train set.
-                 E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
-                 will be used for validation.
+                 E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data will be used for validation.
              max_epochs: Maximum number of epochs for training.
              val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
-                 Note that this affects early stopping, as the model will be stopped if the validation performance does not
-                 improve for patience epochs.
+                 Note that this affects early stopping, as the model will be stopped if the validation performance does not improve for patience epochs.
              patience: Number of validation performance checks without improvement, after which the early stopping flag
                  is activated and training is therefore stopped.
+             lr: Learning rate for training.
+             seed: Random seed for reproducibility.

          Returns:
              AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
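Under the updated signature documented above, a call might look like the following sketch. It is purely illustrative: the toy AnnData is made up, and the class is assumed to remain exposed as pt.tl.MLPClassifierSpace as in earlier releases.

import numpy as np
import pandas as pd
from anndata import AnnData
import pertpy as pt

# Toy AnnData with one perturbation label per cell (illustrative only).
rng = np.random.default_rng(0)
adata = AnnData(X=rng.normal(size=(300, 50)).astype(np.float32))
adata.obs["perturbations"] = pd.Categorical(rng.choice(["control", "drug_a", "drug_b"], size=300))

mlp_space = pt.tl.MLPClassifierSpace()
cell_embeddings = mlp_space.compute(
    adata,
    target_col="perturbations",
    hidden_dim=[64, 32],
    batch_size=128,  # default changed from 256 to 128 in this version
    max_epochs=4,
    lr=1e-4,  # newly exposed parameter
    seed=42,  # newly exposed parameter
)
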
@@ -188,9 +357,9 @@ class MLPClassifierSpace(PerturbationSpace):
          labels = adata.obs[target_col].values.reshape(-1, 1)
          encoder = OneHotEncoder()
          encoded_labels = encoder.fit_transform(labels).toarray()
+         adata = adata.copy()
          adata.obsm["encoded_perturbations"] = encoded_labels.astype(np.float32)

-         # Split the data in train, test and validation
          X = list(range(adata.n_obs))
          y = adata.obs[target_col]

@@ -199,368 +368,107 @@ class MLPClassifierSpace(PerturbationSpace):
              X_train, y_train, test_size=validation_split_size, stratify=y_train
          )

-         train_dataset = PLDataset(
+         train_dataset = JAXDataset(
              adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
          )
-         val_dataset = PLDataset(
+         val_dataset = JAXDataset(
              adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
          )
-         test_dataset = PLDataset(
+         test_dataset = JAXDataset(
              adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-         ) # we don't need to pass y_test since the label selection is done inside
-
-         # Fix class unbalance (likely to happen in perturbation datasets)
-         # Usually control cells are overrepresented such that predicting control all time would give good results
-         # Cells with rare perturbations are sampled more
-         train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
-         train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
-
-         self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
-         self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
-         self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
-
-         # Define the network
-         sizes = [adata.n_vars] + hidden_dim + [n_classes]
-         self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
-
-         # Define a dataset that gathers all the data and dataloader for getting embeddings
-         total_dataset = PLDataset(
-             adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
-         )
-         self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)
-
-         # Save adata observations for embedding annotations in get_embeddings
-         self.adata_obs = adata.obs.reset_index(drop=True)
-
-         self.trainer = Trainer(
-             min_epochs=1,
-             max_epochs=max_epochs,
-             check_val_every_n_epoch=val_epochs_check,
-             callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
-             devices="auto",
-             accelerator="auto",
-         )
-
-         self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
-
-         self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
-         self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)
-
-         # Obtain cell embeddings
-         with torch.no_grad():
-             self.mlp.eval()
-             for dataset_count, batch in enumerate(self.entire_dataset):
-                 emb, y = self.mlp.get_embeddings(batch)
-                 emb = torch.squeeze(emb)
-                 batch_adata = AnnData(X=emb.cpu().numpy())
-                 batch_adata.obs["perturbations"] = y
-                 if dataset_count == 0:
-                     pert_adata = batch_adata
-                 else:
-                     pert_adata = batch_adata if dataset_count == 0 else anndata.concat([pert_adata, batch_adata])
-
-         # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
-         # and we can just add the annotations from the original AnnData object
-         pert_adata.obs = pert_adata.obs.reset_index(drop=True)
-         if "perturbations" in self.adata_obs.columns:
-             self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
-         obs_subset = self.adata_obs.iloc[: len(pert_adata.obs)].copy()
-         for col in obs_subset.columns:
-             if col not in ["perturbations", "encoded_perturbations"]:
-                 pert_adata.obs[col] = obs_subset[col].values
-
-         return pert_adata
-
-     def load(self, adata, **kwargs):
-         """This method is deprecated and will be removed in the future. Please use the compute method instead."""
-         raise DeprecationWarning(
-             "The load method is deprecated and will be removed in the future. Please use the compute method instead."
-         )
-
-     def train(self, **kwargs):
-         """This method is deprecated and will be removed in the future. Please use the compute method instead."""
-         raise DeprecationWarning(
-             "The train method is deprecated and will be removed in the future. Please use the compute method instead."
          )
-
-     def get_embeddings(self, **kwargs):
-         """This method is deprecated and will be removed in the future. Please use the compute method instead."""
-         raise DeprecationWarning(
-             "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
+         total_dataset = JAXDataset(
+             adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
          )

+         rng = random.PRNGKey(seed)
+         rng, init_rng, train_rng = random.split(rng, 3)

- class MLP(torch.nn.Module):
-     """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
-
-     def __init__(
-         self,
-         sizes: list[int],
-         dropout: float = 0.0,
-         batch_norm: bool = True,
-         layer_norm: bool = False,
-         last_layer_act: str = "linear",
-     ) -> None:
-         """Multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
-
-         Args:
-             sizes: size of layers.
-             dropout: Dropout probability.
-             batch_norm: specifies if batch norm should be applied.
-             layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
-             last_layer_act: activation function of last layer.
-         """
-         super().__init__()
-         layers = []
-         for s in range(len(sizes) - 1):
-             layers += [
-                 torch.nn.Linear(sizes[s], sizes[s + 1]),
-                 torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
-                 torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
-                 torch.nn.ReLU(),
-                 torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
-             ]
-
-         layers = [layer for layer in layers if layer is not None][:-1]
-         self.activation = last_layer_act
-         if self.activation == "linear":
-             pass
-         elif self.activation == "ReLU":
-             self.relu = torch.nn.ReLU()
-         else:
-             raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
-
-         self.network = torch.nn.Sequential(*layers)
-
-         self.network.apply(init_weights)
-
-         self.sizes = sizes
-         self.batch_norm = batch_norm
-         self.layer_norm = layer_norm
-         self.last_layer_act = last_layer_act
-
-     def forward(self, x) -> torch.Tensor:
-         if self.activation == "ReLU":
-             return self.relu(self.network(x))
-         return self.network(x)
-
-     def embedding(self, x) -> torch.Tensor:
-         for layer in self.network[:-1]:
-             x = layer(x)
-         return x
-
-
- def init_weights(m):
-     if isinstance(m, torch.nn.Linear):
-         torch.nn.init.kaiming_uniform_(m.weight)
-         m.bias.data.fill_(0.01)
-
-
- class PLDataset(Dataset):
-     """Dataset for perturbation classification.
-
-     Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
-     """
-
-     def __init__(
-         self,
-         adata: np.array,
-         target_col: str = "perturbations",
-         label_col: str = "perturbations",
-         layer_key: str = None,
-     ):
-         """PyTorch lightning Dataset for perturbation classification.
-
-         Args:
-             adata: AnnData object with observations and labels.
-             target_col: key with the perturbation labels numerically encoded.
-             label_col: key with the perturbation labels.
-             layer_key: key of the layer to be used as data, otherwise .X.
-         """
-         if layer_key:
-             self.data = adata.layers[layer_key]
-         else:
-             self.data = adata.X
-
-         if target_col in adata.obs.columns:
-             self.labels = adata.obs[target_col]
-         elif target_col in adata.obsm:
-             self.labels = adata.obsm[target_col]
-         else:
-             raise ValueError(f"Target column {target_col} not found in obs or obsm")
-
-         self.pert_labels = adata.obs[label_col]
-
-     def __len__(self):
-         return self.data.shape[0]
-
-     def __getitem__(self, idx):
-         """Returns a sample and corresponding perturbations applied (labels)."""
-         sample = to_dense(self.data[idx]).squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
-         num_label = self.labels.iloc[idx] if hasattr(self.labels, "iloc") else self.labels[idx]
-         str_label = self.pert_labels.iloc[idx]
-
-         return sample, num_label, str_label
+         sizes = [adata.n_vars] + hidden_dim + [n_classes]
+         model = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)

+         state = create_train_state(init_rng, model, (adata.n_vars,), lr)

- class PerturbationClassifier(LightningModule):
-     def __init__(
-         self,
-         model: torch.nn.Module,
-         batch_size: int,
-         layers: list = [512], # noqa
-         dropout: float = 0.0,
-         batch_norm: bool = True,
-         layer_norm: bool = False,
-         last_layer_act: str = "linear",
-         lr=1e-4,
-         seed=42,
-     ):
-         """Perturbation Classifier.
+         # Create weighted sampling for class imbalance
+         weights = 1.0 / (1.0 + jnp.sum(train_dataset.labels, axis=1))
+         weights = weights / jnp.sum(weights)

-         Args:
-             model: model to be trained
-             batch_size: batch size
-             layers: list of layers of the MLP
-             dropout: dropout probability
-             batch_norm: whether to apply batch norm
-             layer_norm: whether to apply layer norm
-             last_layer_act: activation function of last layer
-             lr: learning rate
-             seed: random seed.
-         """
-         super().__init__()
-         self.batch_size = batch_size
-         self.save_hyperparameters()
-         if model:
-             self.net = model
-         else:
-             self._create_model()
-
-     def _create_model(self):
-         self.net = MLP(
-             sizes=self.hparams.layers,
-             dropout=self.hparams.dropout,
-             batch_norm=self.hparams.batch_norm,
-             layer_norm=self.hparams.layer_norm,
-             last_layer_act=self.hparams.last_layer_act,
+         n_batches_per_epoch = len(train_dataset) // batch_size
+         train_batches = create_batched_indices(
+             len(train_dataset), train_rng, batch_size, max_epochs * n_batches_per_epoch, weights
          )

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Forward pass through the network.
+         best_val_loss = float("inf")
+         patience_counter = 0

-         Args:
-             x: Input tensor
+         for epoch in range(max_epochs):
+             epoch_train_loss = 0

-         Returns:
-             Network output tensor
-         """
-         x = self.net(x)
-         return x
+             epoch_start = epoch * n_batches_per_epoch
+             epoch_end = (epoch + 1) * n_batches_per_epoch
+             epoch_batches = train_batches[epoch_start:epoch_end]

-     def configure_optimizers(self) -> optim.Adam:
-         """Configure optimizer for the model.
+             for _n_train_batches, batch_indices in enumerate(epoch_batches, 1):
+                 rng, step_rng = random.split(rng)
+                 batch_data, batch_labels, *_ = train_dataset.get_batch(batch_indices)
+                 state, loss = train_step(state, (batch_data, batch_labels), step_rng)
+                 epoch_train_loss += loss

-         Returns:
-             Adam optimizer with weight decay
-         """
-         optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
-         return optimizer
-
-     def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
-         """Perform a training step.
-
-         Args:
-             batch: Tuple of (input, target, metadata)
-             batch_idx: Index of the current batch
-
-         Returns:
-             Loss value
-         """
-         x, y, _ = batch
-         x = x.to(torch.float32)
-
-         y_hat = self.forward(x)
-
-         y = torch.argmax(y, dim=1)
-         y_hat = y_hat.squeeze()
+             if (epoch + 1) % val_epochs_check == 0:
+                 val_losses = []
+                 for i in range(0, len(val_dataset), batch_size):
+                     val_indices = jnp.arange(i, min(i + batch_size, len(val_dataset)))
+                     val_batch_data, val_batch_labels, _ = val_dataset.get_batch(val_indices)
+                     val_loss = val_step(state, (val_batch_data, val_batch_labels))
+                     val_losses.append(val_loss)

-         loss = torch.nn.functional.cross_entropy(y_hat, y)
-         self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)
+                 avg_val_loss = jnp.mean(jnp.array(val_losses))

-         return loss
-
-     def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
-         """Perform a validation step.
-
-         Args:
-             batch: Tuple of (input, target, metadata)
-             batch_idx: Index of the current batch
-
-         Returns:
-             Loss value
-         """
-         x, y, _ = batch
-         x = x.to(torch.float32)
-
-         y_hat = self.forward(x)
-
-         y = torch.argmax(y, dim=1)
-         y_hat = y_hat.squeeze()
-
-         loss = torch.nn.functional.cross_entropy(y_hat, y)
-         self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)
-
-         return loss
-
-     def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
-         """Perform a test step.
-
-         Args:
-             batch: Tuple of (input, target, metadata)
-             batch_idx: Index of the current batch
-
-         Returns:
-             Loss value
-         """
-         x, y, _ = batch
-         x = x.to(torch.float32)
+                 if avg_val_loss < best_val_loss:
+                     best_val_loss = avg_val_loss
+                     patience_counter = 0
+                 else:
+                     patience_counter += 1

-         y_hat = self.forward(x)
+                 if patience_counter >= patience:
+                     break

-         y = torch.argmax(y, dim=1)
-         y_hat = y_hat.squeeze()
+         # Test evaluation
+         test_losses = []
+         for i in range(0, len(test_dataset), batch_size):
+             test_indices = jnp.arange(i, min(i + batch_size, len(test_dataset)))
+             test_batch_data, test_batch_labels, _ = test_dataset.get_batch(test_indices)
+             test_loss = val_step(state, (test_batch_data, test_batch_labels))
+             test_losses.append(test_loss)

-         loss = torch.nn.functional.cross_entropy(y_hat, y)
-         self.log("test_loss", loss, prog_bar=True, batch_size=self.batch_size)
+         # Extract embeddings
+         embeddings_list = []
+         labels_list = []

-         return loss
+         for i in range(0, len(total_dataset), batch_size * 2):
+             indices = jnp.arange(i, min(i + batch_size * 2, len(total_dataset)))
+             batch_data, _, batch_pert_labels = total_dataset.get_batch(indices)
+             batch_embeddings = get_embeddings(state, batch_data)

-     def embedding(self, x: torch.Tensor) -> torch.Tensor:
-         """Extract embeddings from input features.
+             embeddings_list.append(batch_embeddings)
+             labels_list.extend(batch_pert_labels)

-         Args:
-             x: Input tensor of shape [Batch, SeqLen, 1]
+         all_embeddings = jnp.concatenate(embeddings_list, axis=0)

-         Returns:
-             Embedded representation of the input
-         """
-         x = self.net.embedding(x)
-         return x
+         pert_adata = AnnData(X=np.array(all_embeddings))
+         pert_adata.obs["perturbations"] = labels_list

-     def get_embeddings(
-         self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         """Extract embeddings from a batch.
+         adata_obs = adata.obs.reset_index(drop=True)
+         if "perturbations" in adata_obs.columns:
+             adata_obs = adata_obs.drop("perturbations", axis=1)

-         Args:
-             batch: Tuple of (input, target, metadata)
+         obs_subset = adata_obs.iloc[: len(pert_adata.obs)].copy()
+         cols_to_add = [col for col in obs_subset.columns if col not in ["perturbations", "encoded_perturbations"]]
+         new_cols_data = {col: obs_subset[col].values for col in cols_to_add}

-         Returns:
-             Tuple of (embeddings, metadata)
-         """
-         x, _, y = batch
-         x = x.to(torch.float32)
+         if new_cols_data:
+             pert_adata.obs = pd.concat(
+                 [pert_adata.obs, pd.DataFrame(new_cols_data, index=pert_adata.obs.index)], axis=1
+             )

-         embedding = self.embedding(x)
-         return embedding, y
+         return pert_adata
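
The removed WeightedRandomSampler is replaced by create_batched_indices, which draws indices with jax.random.choice and an optional probability vector. The short self-contained sketch below illustrates that sampling idea with a made-up inverse-class-frequency weighting, not the exact weights computed in compute:

import jax.numpy as jnp
from jax import random

# Hypothetical labels: class 0 makes up 90% of the cells.
labels = jnp.array([0] * 90 + [1] * 10)

# One common weighting: inverse class frequency, normalised to sum to 1.
class_counts = jnp.bincount(labels, length=2)
weights = 1.0 / class_counts[labels]
weights = weights / weights.sum()

# Each draw returns a batch of indices; rare-class cells are picked far more
# often than their 10% share of the data (labels[idx].mean() is roughly 0.5).
idx = random.choice(random.PRNGKey(0), labels.shape[0], shape=(64,), p=weights)

create_batched_indices applies the same idea per batch, splitting the PRNG key each time so that the sequence of batches is reproducible for a fixed seed.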