pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_perturbation_space/_clustering.py
@@ -7,6 +7,8 @@ from sklearn.metrics import pairwise_distances
 from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
 
 if TYPE_CHECKING:
+    from collections.abc import Iterable
+
     from anndata import AnnData
 
 
@@ -14,6 +16,7 @@ class ClusteringSpace(PerturbationSpace):
     """Applies various clustering techniques to an embedding."""
 
     def __init__(self):
+        super().__init__()
         self.X = None
 
     def evaluate_clustering(
@@ -21,7 +24,7 @@ class ClusteringSpace(PerturbationSpace):
         adata: AnnData,
         true_label_col: str,
         cluster_col: str,
-        metrics: list[str] = None,
+        metrics: Iterable[str] = None,
         **kwargs,
     ):
         """Evaluation of previously computed clustering against ground truth labels.
@@ -30,7 +33,9 @@ class ClusteringSpace(PerturbationSpace):
             adata: AnnData object that contains the clustered data and the cluster labels.
             true_label_col: ground truth labels.
            cluster_col: cluster computed labels.
-            metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw'].
+            metrics: Metrics to compute. If `None` it defaults to ["nmi", "ari", "asw"].
+            **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
+                For asw, metric, distances, sample_size, and random_state can be passed.
 
         Examples:
             Example usage with KMeansSpace:
@@ -39,7 +44,9 @@ class ClusteringSpace(PerturbationSpace):
             >>> mdata = pt.dt.papalexi_2021()
             >>> kmeans = pt.tl.KMeansSpace()
            >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
-            >>> results = kmeans.evaluate_clustering(kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=['nmi'])
+            >>> results = kmeans.evaluate_clustering(
+            ...     kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
+            ... )
         """
         if metrics is None:
             metrics = ["nmi", "ari", "asw"]
pertpy/tools/_perturbation_space/_comparison.py
@@ -0,0 +1,112 @@
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pynndescent
+from scipy.sparse import issparse
+from scipy.sparse import vstack as sp_vstack
+from sklearn.base import ClassifierMixin
+from sklearn.linear_model import LogisticRegression
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class PerturbationComparison:
+    """Comparison between real and simulated perturbations."""
+
+    def compare_classification(
+        self,
+        real: np.ndarray,
+        simulated: np.ndarray,
+        control: np.ndarray,
+        clf: ClassifierMixin | None = None,
+    ) -> float:
+        """Compare classification accuracy between real and simulated perturbations.
+
+        Trains a classifier on the real perturbation data + the control data and reports a normalized
+        classification accuracy on the simulated perturbation.
+
+        Args:
+            real: Real perturbed data.
+            simulated: Simulated perturbed data.
+            control: Control data.
+            clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
+        """
+        assert real.shape[1] == simulated.shape[1] == control.shape[1]
+        if clf is None:
+            clf = LogisticRegression()
+        n_x = real.shape[0]
+        data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control))
+        labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")])
+
+        clf.fit(data, labels)
+        norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x])
+        norm_score = min(1.0, norm_score)
+
+        return norm_score
+
+    def compare_knn(
+        self,
+        real: np.ndarray,
+        simulated: np.ndarray,
+        control: np.ndarray | None = None,
+        use_simulated_for_knn: bool = False,
+        n_neighbors: int = 20,
+        random_state: int = 0,
+        n_jobs: int = 1,
+    ) -> dict[str, float]:
+        """Calculate proportions of real perturbed and control data points for simulated data.
+
+        Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
+        data points for simulated data. If `control` is not provided, builds the knn graph from
+        real perturbed + simulated perturbed.
+
+        Args:
+            real: Real perturbed data.
+            simulated: Simulated perturbed data.
+            control: Control data.
+            use_simulated_for_knn: Include simulated perturbed data (`simulated`) in the knn graph. Only valid when
+                control (`control`) is provided.
+            n_neighbors: Number of neighbors to use in the k-neighbor graph.
+            random_state: Random state used for k-neighbor graph construction.
+            n_jobs: Number of cores to use. Defaults to 1.
+        """
+        assert real.shape[1] == simulated.shape[1]
+        if control is not None:
+            assert real.shape[1] == control.shape[1]
+
+        n_y = simulated.shape[0]
+
+        if control is None:
+            index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real))
+        else:
+            datas = (simulated, real, control) if use_simulated_for_knn else (real, control)
+            index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas)
+
+        y_in_index = use_simulated_for_knn or control is None
+        c_in_index = control is not None
+        label_groups = ["comp"]
+        labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp")
+        if y_in_index:
+            labels[:n_y] = "siml"
+            label_groups.append("siml")
+        if c_in_index:
+            labels[-control.shape[0] :] = "ctrl"
+            label_groups.append("ctrl")
+
+        index = pynndescent.NNDescent(
+            index_data,
+            n_neighbors=max(50, n_neighbors),
+            random_state=random_state,
+            n_jobs=n_jobs,
+        )
+        indices = index.query(simulated, k=n_neighbors)[0]
+
+        uq, uq_counts = np.unique(labels[indices], return_counts=True)
+        uq_counts_norm = uq_counts / uq_counts.sum()
+        counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
+        for group, count_norm in zip(uq, uq_counts_norm, strict=False):
+            counts[group] = count_norm
+
+        return counts
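A minimal usage sketch of the new comparison API, with synthetic arrays as placeholders (the deep module import mirrors the file path above, since this diff does not show a top-level alias):

>>> import numpy as np
>>> from pertpy.tools._perturbation_space._comparison import PerturbationComparison
>>> rng = np.random.default_rng(0)
>>> real = rng.normal(size=(100, 20))  # real perturbed cells x features
>>> simulated = real + rng.normal(scale=0.1, size=(100, 20))  # simulated perturbation
>>> control = rng.normal(loc=1.0, size=(80, 20))  # control cells
>>> comparison = PerturbationComparison()
>>> score = comparison.compare_classification(real, simulated, control)  # normalized accuracy, capped at 1.0
>>> proportions = comparison.compare_knn(real, simulated, control)  # {"comp": ..., "ctrl": ...}, summing to 1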
pertpy/tools/_perturbation_space/_discriminator_classifiers.py
@@ -0,0 +1,524 @@
+from __future__ import annotations
+
+import warnings
+
+import anndata
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import scipy
+import torch
+from anndata import AnnData
+from pytorch_lightning.callbacks import EarlyStopping
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import OneHotEncoder
+from torch import optim
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
+
+from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
+
+
+class LRClassifierSpace(PerturbationSpace):
+    """Fits a logistic regression model to the data and takes the feature space as embedding.
+
+    We fit one logistic regression model per perturbation. After training, the coefficients of the logistic regression
+    model are used as the feature space. This results in one embedding per perturbation.
+    """
+
+    def compute(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str = None,
+        embedding_key: str = None,
+        test_split_size: float = 0.2,
+        max_iter: int = 1000,
+    ):
+        """Fits a logistic regression model to the data and takes the coefficients of the logistic regression
+        model as perturbation embedding.
+
+        Args:
+            adata: AnnData object of size cells x genes.
+            target_col: .obs column that stores the perturbations.
+            layer_key: Layer in adata to use.
+            embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
+                Can only be specified if layer_key is None.
+            test_split_size: Fraction of data to put in the test set.
+            max_iter: Maximum number of iterations taken for the solvers to converge.
+
+        Returns:
+            AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> rcs = pt.tl.LRClassifierSpace()
+            >>> pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.layers:
+            raise ValueError(f"Layer key {layer_key} not found in adata.")
+
+        if embedding_key is not None and embedding_key not in adata.obsm.keys():
+            raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")
+
+        if layer_key is not None and embedding_key is not None:
+            raise ValueError("Cannot specify both layer_key and embedding_key.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if layer_key is not None:
+            regression_data = adata.layers[layer_key]
+        elif embedding_key is not None:
+            regression_data = adata.obsm[embedding_key]
+        else:
+            regression_data = adata.X
+
+        regression_labels = adata.obs[target_col]
+
+        # Save adata observations for embedding annotations in get_embeddings
+        adata_obs = adata.obs.reset_index(drop=True)
+        adata_obs = adata_obs.groupby(target_col).agg(
+            lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
+        )
+
+        # Fit a logistic regression model for each perturbation
+        regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
+        regression_embeddings = {}
+        regression_scores = {}
+
+        for perturbation in regression_labels.unique():
+            labels = np.where(regression_labels == perturbation, 1, 0)
+            X_train, X_test, y_train, y_test = train_test_split(
+                regression_data, labels, test_size=test_split_size, stratify=labels
+            )
+
+            regression_model.fit(X_train, y_train)
+            regression_embeddings[perturbation] = regression_model.coef_
+            regression_scores[perturbation] = regression_model.score(X_test, y_test)
+
+        # Save the regression embeddings and scores in an AnnData object
+        pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
+        pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
+        pert_adata.obs["classifier_score"] = list(regression_scores.values())
+
+        # Save adata observations for embedding annotations
+        for obs_name in adata_obs.columns:
+            if not adata_obs[obs_name].isnull().values.any():
+                pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
+                    {pert: adata_obs.loc[pert][obs_name] for pert in adata_obs.index}
+                )
+
+        return pert_adata
+
+
+# Ensure backward compatibility with DiscriminatorClassifierSpace
+def DiscriminatorClassifierSpace():
+    warnings.warn(
+        "The DiscriminatorClassifierSpace class is deprecated and will be removed in the future. "
+        "Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    return MLPClassifierSpace()
+
+
+class MLPClassifierSpace(PerturbationSpace):
+    """Fits an ANN classifier to the data and takes the activations of the penultimate layer as embedding.
+
+    We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
+    feature space, resulting in one embedding per cell. Consider employing the PseudoBulk or another PerturbationSpace
+    to obtain one embedding per perturbation.
+
+    See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Supplementary Figures 17-19.
+    """
+
+    def compute(  # type: ignore
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str = None,
+        hidden_dim: list[int] = None,
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        batch_size: int = 256,
+        test_split_size: float = 0.2,
+        validation_split_size: float = 0.25,
+        max_epochs: int = 20,
+        val_epochs_check: int = 2,
+        patience: int = 2,
+    ) -> AnnData:
+        """Creates cell embeddings by training an MLP classifier model to distinguish between perturbations.
+
+        A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
+        the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
+        Dataloaders that take into account class imbalances are created. Next, the model is trained and tested, using the
+        GPU if available. The embeddings are obtained by passing the data through the model and extracting the values in
+        the last hidden layer of the MLP. You will get one embedding per cell, so be aware that you might need to apply
+        another perturbation space to aggregate the embeddings per perturbation.
+
+        Args:
+            adata: AnnData object of size cells x genes.
+            target_col: .obs column that stores the perturbations.
+            layer_key: Layer in adata to use.
+            hidden_dim: List of the number of neurons in each hidden layer of the neural network. For instance, [512, 256]
+                will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
+            dropout: Amount of dropout applied, constant for all layers.
+            batch_norm: Whether to apply batch normalization.
+            batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
+            test_split_size: Fraction of data to put in the test set. Defaults to 0.2.
+            validation_split_size: Fraction of data to put in the validation set of the resultant train set.
+                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
+                will be used for validation.
+            max_epochs: Maximum number of epochs for training.
+            val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
+                Note that this affects early stopping, as the model will be stopped if the validation performance does not
+                improve for patience epochs.
+            patience: Number of validation performance checks without improvement, after which the early stopping flag
+                is activated and training is therefore stopped.
+
+        Returns:
+            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
+            The AnnData will have shape (n_cells, n_features) where n_features is the number of neurons in the last hidden layer of the MLP.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> dcs = pt.tl.MLPClassifierSpace()
+            >>> cell_embeddings = dcs.compute(adata, target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.layers:
+            raise ValueError(f"Layer key {layer_key} not found in adata.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if hidden_dim is None:
+            hidden_dim = [512]
+
+        # Labels are strings, one-hot encoding for classification
+        n_classes = len(adata.obs[target_col].unique())
+        labels = adata.obs[target_col].values.reshape(-1, 1)
+        encoder = OneHotEncoder()
+        encoded_labels = encoder.fit_transform(labels).toarray()
+        adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
+
+        # Split the data into train, test and validation sets
+        X = list(range(0, adata.n_obs))
+        y = adata.obs[target_col]
+
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
+        X_train, X_val, y_train, y_val = train_test_split(
+            X_train, y_train, test_size=validation_split_size, stratify=y_train
+        )
+
+        train_dataset = PLDataset(
+            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        val_dataset = PLDataset(
+            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        test_dataset = PLDataset(
+            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )  # we don't need to pass y_test since the label selection is done inside
+
+        # Fix class imbalance (likely to happen in perturbation datasets)
+        # Usually control cells are overrepresented such that predicting control all the time would give good results
+        # Cells with rare perturbations are sampled more often
+        train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
+        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
+
+        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
+        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+        # Define the network
+        sizes = [adata.n_vars] + hidden_dim + [n_classes]
+        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
+
+        # Define a dataset that gathers all the data and a dataloader for getting embeddings
+        total_dataset = PLDataset(
+            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)
+
+        # Save adata observations for embedding annotations in get_embeddings
+        self.adata_obs = adata.obs.reset_index(drop=True)
+
+        self.trainer = pl.Trainer(
+            min_epochs=1,
+            max_epochs=max_epochs,
+            check_val_every_n_epoch=val_epochs_check,
+            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
+            devices="auto",
+            accelerator="auto",
+        )
+
+        self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
+
+        self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
+        self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)
+
+        # Obtain cell embeddings
+        with torch.no_grad():
+            self.mlp.eval()
+            for dataset_count, batch in enumerate(self.entire_dataset):
+                emb, y = self.mlp.get_embeddings(batch)
+                emb = torch.squeeze(emb)
+                batch_adata = AnnData(X=emb.cpu().numpy())
+                batch_adata.obs["perturbations"] = y
+                if dataset_count == 0:
+                    pert_adata = batch_adata
+                else:
+                    pert_adata = anndata.concat([pert_adata, batch_adata])
+
+        # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
+        # and we can just add the annotations from the original AnnData object
+        pert_adata.obs = pert_adata.obs.reset_index(drop=True)
+        if "perturbations" in self.adata_obs.columns:
+            self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
+        pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)
+
+        # Drop the 'encoded_perturbations' column, since it stores the one-hot encoded labels as numpy arrays,
+        # which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
+        pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
+
+        return pert_adata
+
+    def load(self, adata, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The load method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def train(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The train method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def get_embeddings(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+
+class MLP(torch.nn.Module):
+    """A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm."""
+
+    def __init__(
+        self,
+        sizes: list[int],
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+    ) -> None:
+        """
+        Args:
+            sizes: size of layers.
+            dropout: Dropout probability.
+            batch_norm: specifies if batch norm should be applied.
+            layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
+            last_layer_act: activation function of last layer.
+        """
+        super().__init__()
+        layers = []
+        for s in range(len(sizes) - 1):
+            layers += [
+                torch.nn.Linear(sizes[s], sizes[s + 1]),
+                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
+                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
+                torch.nn.ReLU(),
+                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
+            ]
+
+        layers = [layer for layer in layers if layer is not None][:-1]
+        self.activation = last_layer_act
+        if self.activation == "linear":
+            pass
+        elif self.activation == "ReLU":
+            self.relu = torch.nn.ReLU()
+        else:
+            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
+
+        self.network = torch.nn.Sequential(*layers)
+
+        self.network.apply(init_weights)
+
+        self.sizes = sizes
+        self.batch_norm = batch_norm
+        self.layer_norm = layer_norm
+        self.last_layer_act = last_layer_act
+
+    def forward(self, x) -> torch.Tensor:
+        if self.activation == "ReLU":
+            return self.relu(self.network(x))
+        return self.network(x)
+
+    def embedding(self, x) -> torch.Tensor:
+        for layer in self.network[:-1]:
+            x = layer(x)
+        return x
+
+
+def init_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_uniform_(m.weight)
+        m.bias.data.fill_(0.01)
+
+
+class PLDataset(Dataset):
+    """Dataset for perturbation classification.
+
+    Needed for training a model that classifies the perturbed cells and takes as perturbation embedding the second to last layer.
+    """
+
+    def __init__(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        label_col: str = "perturbations",
+        layer_key: str = None,
+    ):
+        """
+        Args:
+            adata: AnnData object with observations and labels.
+            target_col: key with the perturbation labels numerically encoded.
+            label_col: key with the perturbation labels.
+            layer_key: key of the layer to be used as data, otherwise .X.
+        """
+        if layer_key:
+            self.data = adata.layers[layer_key]
+        else:
+            self.data = adata.X
+
+        self.labels = adata.obs[target_col]
+        self.pert_labels = adata.obs[label_col]
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        """Returns a sample and the corresponding perturbations applied (labels)."""
+        sample = self.data[idx].A.squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
+        num_label = self.labels.iloc[idx]
+        str_label = self.pert_labels.iloc[idx]
+
+        return sample, num_label, str_label
+
+
+class PerturbationClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        layers: list = [512],  # noqa
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+        lr=1e-4,
+        seed=42,
+    ):
+        """
+        Args:
+            model: model to be trained
+            batch_size: batch size
+            layers: list of layers of the MLP
+            dropout: dropout probability
+            batch_norm: whether to apply batch norm
+            layer_norm: whether to apply layer norm
+            last_layer_act: activation function of last layer
+            lr: learning rate
+            seed: random seed
+        """
+        super().__init__()
+        self.batch_size = batch_size
+        self.save_hyperparameters()
+        if model:
+            self.net = model
+        else:
+            self._create_model()
+
+    def _create_model(self):
+        self.net = MLP(
+            sizes=self.hparams.layers,
+            dropout=self.hparams.dropout,
+            batch_norm=self.hparams.batch_norm,
+            layer_norm=self.hparams.layer_norm,
+            last_layer_act=self.hparams.last_layer_act,
+        )
+
+    def forward(self, x):
+        x = self.net(x)
+        return x
+
+    def configure_optimizers(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
+
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("test_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def embedding(self, x):
+        """
+        Inputs:
+            x: Input features of shape [Batch, Features]
+        """
+        x = self.net.embedding(x)
+        return x
+
+    def get_embeddings(self, batch):
+        x, _, y = batch
+        x = x.to(torch.float32)
+
+        embedding = self.embedding(x)
+        return embedding, y
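To make the embedding extraction concrete: `forward` runs the full network and returns class logits, while `embedding` stops before the final Linear layer, which is what `PerturbationClassifier.get_embeddings` collects per batch. A small sketch, assuming the module above is importable and using arbitrary sizes:

>>> import torch
>>> net = MLP(sizes=[2000, 512, 10])  # n_vars -> hidden -> n_classes
>>> x = torch.randn(8, 2000)  # a batch of 8 cells
>>> net(x).shape  # class logits
torch.Size([8, 10])
>>> net.embedding(x).shape  # last-hidden-layer activations used as cell embeddings
torch.Size([8, 512])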