pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0

pertpy/tools/_perturbation_space/_clustering.py

@@ -7,6 +7,8 @@ from sklearn.metrics import pairwise_distances
 from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
 
 if TYPE_CHECKING:
+    from collections.abc import Iterable
+
     from anndata import AnnData
 
 
@@ -14,6 +16,7 @@ class ClusteringSpace(PerturbationSpace):
     """Applies various clustering techniques to an embedding."""
 
     def __init__(self):
+        super().__init__()
         self.X = None
 
     def evaluate_clustering(
@@ -21,7 +24,7 @@ class ClusteringSpace(PerturbationSpace):
         adata: AnnData,
         true_label_col: str,
         cluster_col: str,
-        metrics: list[str] = None,
+        metrics: Iterable[str] = None,
         **kwargs,
     ):
         """Evaluation of previously computed clustering against ground truth labels.
@@ -30,7 +33,9 @@ class ClusteringSpace(PerturbationSpace):
             adata: AnnData object that contains the clustered data and the cluster labels.
             true_label_col: ground truth labels.
             cluster_col: cluster computed labels.
-            metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw'].
+            metrics: Metrics to compute. If `None` it defaults to ["nmi", "ari", "asw"].
+            **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
+                For asw, metric, distances, sample_size, and random_state can be passed.
 
         Examples:
             Example usage with KMeansSpace:
@@ -39,7 +44,9 @@ class ClusteringSpace(PerturbationSpace):
             >>> mdata = pt.dt.papalexi_2021()
             >>> kmeans = pt.tl.KMeansSpace()
             >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)
-            >>> results = kmeans.evaluate_clustering(kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=['nmi'])
+            >>> results = kmeans.evaluate_clustering(
+            ...     kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"]
+            ... )
         """
         if metrics is None:
            metrics = ["nmi", "ari", "asw"]
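
The expanded **kwargs documentation above implies per-metric options; a minimal sketch of passing one through (dataset and column names taken from the docstring example; the average_method value is an assumption, valid for sklearn's NMI):

    import pertpy as pt

    mdata = pt.dt.papalexi_2021()
    kmeans = pt.tl.KMeansSpace()
    kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26)

    # average_method is forwarded to the nmi metric, per the Args section above
    results = kmeans.evaluate_clustering(
        kmeans_adata,
        true_label_col="gene_target",
        cluster_col="k-means",
        metrics=["nmi"],
        average_method="arithmetic",  # assumption: any average_method accepted by sklearn's NMI
    )
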
pertpy/tools/_perturbation_space/_comparison.py

@@ -0,0 +1,112 @@
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pynndescent
+from scipy.sparse import issparse
+from scipy.sparse import vstack as sp_vstack
+from sklearn.base import ClassifierMixin
+from sklearn.linear_model import LogisticRegression
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class PerturbationComparison:
+    """Comparison between real and simulated perturbations."""
+
+    def compare_classification(
+        self,
+        real: np.ndarray,
+        simulated: np.ndarray,
+        control: np.ndarray,
+        clf: ClassifierMixin | None = None,
+    ) -> float:
+        """Compare classification accuracy between real and simulated perturbations.
+
+        Trains a classifier on the real perturbation data + the control data and reports a normalized
+        classification accuracy on the simulated perturbation.
+
+        Args:
+            real: Real perturbed data.
+            simulated: Simulated perturbed data.
+            control: Control data.
+            clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided.
+        """
+        assert real.shape[1] == simulated.shape[1] == control.shape[1]
+        if clf is None:
+            clf = LogisticRegression()
+        n_x = real.shape[0]
+        data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control))
+        labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")])
+
+        clf.fit(data, labels)
+        norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x])
+        norm_score = min(1.0, norm_score)
+
+        return norm_score
+
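
A minimal sketch of compare_classification on toy arrays (the data and its distributions are invented for illustration); a score near 1.0 means the classifier accepts simulated cells as perturbed about as readily as real ones:

    import numpy as np

    from pertpy.tools._perturbation_space._comparison import PerturbationComparison

    rng = np.random.default_rng(0)
    real = rng.normal(1.0, 1.0, size=(100, 10))  # real perturbed cells
    simulated = rng.normal(1.0, 1.0, size=(100, 10))  # simulated perturbed cells
    control = rng.normal(0.0, 1.0, size=(100, 10))  # control cells

    score = PerturbationComparison().compare_classification(real, simulated, control)
    print(score)  # normalized accuracy, capped at 1.0
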
+    def compare_knn(
+        self,
+        real: np.ndarray,
+        simulated: np.ndarray,
+        control: np.ndarray | None = None,
+        use_simulated_for_knn: bool = False,
+        n_neighbors: int = 20,
+        random_state: int = 0,
+        n_jobs: int = 1,
+    ) -> dict[str, float]:
+        """Calculate proportions of real perturbed and control data points for simulated data.
+
+        Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`)
+        data points for simulated data. If control (`control`) is not provided, builds the knn graph from
+        real perturbed + simulated perturbed.
+
+        Args:
+            real: Real perturbed data.
+            simulated: Simulated perturbed data.
+            control: Control data.
+            use_simulated_for_knn: Include simulated perturbed data (`simulated`) in the knn graph. Only valid when
+                control (`control`) is provided.
+            n_neighbors: Number of neighbors to use in the k-neighbor graph.
+            random_state: Random state used for k-neighbor graph construction.
+            n_jobs: Number of cores to use.
+
+        """
+        assert real.shape[1] == simulated.shape[1]
+        if control is not None:
+            assert real.shape[1] == control.shape[1]
+
+        n_y = simulated.shape[0]
+
+        if control is None:
+            index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real))
+        else:
+            datas = (simulated, real, control) if use_simulated_for_knn else (real, control)
+            index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas)
+
+        y_in_index = use_simulated_for_knn or control is None
+        c_in_index = control is not None
+        label_groups = ["comp"]
+        labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp")
+        if y_in_index:
+            labels[:n_y] = "siml"
+            label_groups.append("siml")
+        if c_in_index:
+            labels[-control.shape[0] :] = "ctrl"
+            label_groups.append("ctrl")
+
+        index = pynndescent.NNDescent(
+            index_data,
+            n_neighbors=max(50, n_neighbors),
+            random_state=random_state,
+            n_jobs=n_jobs,
+        )
+        indices = index.query(simulated, k=n_neighbors)[0]
+
+        uq, uq_counts = np.unique(labels[indices], return_counts=True)
+        uq_counts_norm = uq_counts / uq_counts.sum()
+        counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
+        for group, count_norm in zip(uq, uq_counts_norm, strict=False):
+            counts[group] = count_norm
+
+        return counts
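
A companion sketch for compare_knn on the same kind of toy data (arrays invented for illustration); the returned dict gives, per labeled group in the index, the fraction of the simulated cells' neighbors that fall into it:

    import numpy as np

    from pertpy.tools._perturbation_space._comparison import PerturbationComparison

    rng = np.random.default_rng(0)
    real = rng.normal(1.0, 1.0, size=(200, 10))
    simulated = rng.normal(1.0, 1.0, size=(200, 10))
    control = rng.normal(0.0, 1.0, size=(200, 10))

    proportions = PerturbationComparison().compare_knn(real, simulated, control, n_neighbors=20)
    print(proportions)  # e.g. {"comp": 0.93, "ctrl": 0.07}; a high "comp" share puts simulated cells among real perturbed ones
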
pertpy/tools/_perturbation_space/_discriminator_classifiers.py

@@ -0,0 +1,524 @@
+from __future__ import annotations
+
+import warnings
+
+import anndata
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import scipy
+import torch
+from anndata import AnnData
+from pytorch_lightning.callbacks import EarlyStopping
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import OneHotEncoder
+from torch import optim
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
+
+from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
+
+
+class LRClassifierSpace(PerturbationSpace):
+    """Fits a logistic regression model to the data and takes the feature space as embedding.
+
+    We fit one logistic regression model per perturbation. After training, the coefficients of the logistic regression
+    model are used as the feature space. This results in one embedding per perturbation.
+    """
+
+    def compute(
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str = None,
+        embedding_key: str = None,
+        test_split_size: float = 0.2,
+        max_iter: int = 1000,
+    ):
+        """
+        Fits a logistic regression model to the data and takes the coefficients of the logistic regression
+        model as perturbation embedding.
+
+        Args:
+            adata: AnnData object of size cells x genes
+            target_col: .obs column that stores the perturbations.
+            layer_key: Layer in adata to use.
+            embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
+                Can only be specified if layer_key is None.
+            test_split_size: Fraction of data to put in the test set.
+            max_iter: Maximum number of iterations taken for the solvers to converge.
+
+        Returns:
+            AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> rcs = pt.tl.LRClassifierSpace()
+            >>> pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.obs.columns:
+            raise ValueError(f"Layer key {layer_key} not found in adata.")
+
+        if embedding_key is not None and embedding_key not in adata.obsm.keys():
+            raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")
+
+        if layer_key is not None and embedding_key is not None:
+            raise ValueError("Cannot specify both layer_key and embedding_key.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if layer_key is not None:
+            regression_data = adata.layers[layer_key]
+        elif embedding_key is not None:
+            regression_data = adata.obsm[embedding_key]
+        else:
+            regression_data = adata.X
+
+        regression_labels = adata.obs[target_col]
+
+        # Save adata observations for embedding annotations in get_embeddings
+        adata_obs = adata.obs.reset_index(drop=True)
+        adata_obs = adata_obs.groupby(target_col).agg(
+            lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
+        )
+
+        # Fit a logistic regression model for each perturbation
+        regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
+        regression_embeddings = {}
+        regression_scores = {}
+
+        for perturbation in regression_labels.unique():
+            labels = np.where(regression_labels == perturbation, 1, 0)
+            X_train, X_test, y_train, y_test = train_test_split(
+                regression_data, labels, test_size=test_split_size, stratify=labels
+            )
+
+            regression_model.fit(X_train, y_train)
+            regression_embeddings[perturbation] = regression_model.coef_
+            regression_scores[perturbation] = regression_model.score(X_test, y_test)
+
+        # Save the regression embeddings and scores in an AnnData object
+        pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
+        pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
+        pert_adata.obs["classifier_score"] = list(regression_scores.values())
+
+        # Save adata observations for embedding annotations
+        for obs_name in adata_obs.columns:
+            if not adata_obs[obs_name].isnull().values.any():
+                pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
+                    {pert: adata_obs.loc[pert][obs_name] for pert in adata_obs.index}
+                )
+
+        return pert_adata
+
+
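Following the docstring example, a short sketch of inspecting the per-perturbation held-out accuracy that compute() stores next to the coefficient embeddings (dataset and columns as above):

    import pertpy as pt

    adata = pt.dt.norman_2019()
    rcs = pt.tl.LRClassifierSpace()
    pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")

    # One row per perturbation: X holds the coefficient vector, .obs the test accuracy
    print(pert_embeddings.obs[["perturbations", "classifier_score"]].sort_values("classifier_score").head())
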
+# Ensure backward compatibility with DiscriminatorClassifierSpace
+def DiscriminatorClassifierSpace():
+    warnings.warn(
+        "The DiscriminatorClassifierSpace class is deprecated and will be removed in the future. "
+        "Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    return MLPClassifierSpace()
+
+
+class MLPClassifierSpace(PerturbationSpace):
+    """Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.
+
+    We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
+    feature space, resulting in one embedding per cell. Consider employing the PseudoBulk or another PerturbationSpace
+    to obtain one embedding per perturbation.
+
+    See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
+    """
+
+    def compute(  # type: ignore
+        self,
+        adata: AnnData,
+        target_col: str = "perturbations",
+        layer_key: str = None,
+        hidden_dim: list[int] = None,
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        batch_size: int = 256,
+        test_split_size: float = 0.2,
+        validation_split_size: float = 0.25,
+        max_epochs: int = 20,
+        val_epochs_check: int = 2,
+        patience: int = 2,
+    ) -> AnnData:
+        """Creates cell embeddings by training an MLP classifier model to distinguish between perturbations.
+
+        A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
+        the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
+        Dataloaders that take into account class imbalances are created. Next, the model is trained and tested, using the
+        GPU if available. The embeddings are obtained by passing the data through the model and extracting the values in
+        the last layer of the MLP. You will get one embedding per cell, so be aware that you might need to apply another
+        perturbation space to aggregate the embeddings per perturbation.
+
+        Args:
+            adata: AnnData object of size cells x genes
+            target_col: .obs column that stores the perturbations.
+            layer_key: Layer in adata to use.
+            hidden_dim: List of the number of neurons in each hidden layer of the neural network. For instance, [512, 256]
+                will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
+            dropout: Amount of dropout applied, constant for all layers.
+            batch_norm: Whether to apply batch normalization.
+            batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass.
+            test_split_size: Fraction of data to put in the test set. Defaults to 0.2.
+            validation_split_size: Fraction of data to put in the validation set of the resultant train set.
+                E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
+                will be used for validation.
+            max_epochs: Maximum number of epochs for training.
+            val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
+                Note that this affects early stopping, as the model will be stopped if the validation performance does not
+                improve for patience epochs.
+            patience: Number of validation performance checks without improvement, after which the early stopping flag
+                is activated and training is therefore stopped.
+
+        Returns:
+            AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
+            The AnnData will have shape (n_cells, n_features) where n_features is the number of features in the last layer of the MLP.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.norman_2019()
+            >>> dcs = pt.tl.MLPClassifierSpace()
+            >>> cell_embeddings = dcs.compute(adata, target_col="perturbation_name")
+        """
+        if layer_key is not None and layer_key not in adata.obs.columns:
+            raise ValueError(f"Layer key {layer_key} not found in adata.")
+
+        if target_col not in adata.obs:
+            raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
+
+        if hidden_dim is None:
+            hidden_dim = [512]
+
+        # Labels are strings, one hot encoding for classification
+        n_classes = len(adata.obs[target_col].unique())
+        labels = adata.obs[target_col].values.reshape(-1, 1)
+        encoder = OneHotEncoder()
+        encoded_labels = encoder.fit_transform(labels).toarray()
+        adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
+
+        # Split the data in train, test and validation
+        X = list(range(0, adata.n_obs))
+        y = adata.obs[target_col]
+
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split_size, stratify=y)
+        X_train, X_val, y_train, y_val = train_test_split(
+            X_train, y_train, test_size=validation_split_size, stratify=y_train
+        )
+
+        train_dataset = PLDataset(
+            adata=adata[X_train], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        val_dataset = PLDataset(
+            adata=adata[X_val], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        test_dataset = PLDataset(
+            adata=adata[X_test], target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )  # we don't need to pass y_test since the label selection is done inside
+
+        # Fix class imbalance (likely to happen in perturbation datasets)
+        # Usually control cells are overrepresented such that predicting control all the time would give good results
+        # Cells with rare perturbations are sampled more
+        train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
+        train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
+
+        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
+        self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+        self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+        # Define the network
+        sizes = [adata.n_vars] + hidden_dim + [n_classes]
+        self.net = MLP(sizes=sizes, dropout=dropout, batch_norm=batch_norm)
+
+        # Define a dataset that gathers all the data and a dataloader for getting embeddings
+        total_dataset = PLDataset(
+            adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
+        )
+        self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)
+
+        # Save adata observations for embedding annotations in get_embeddings
+        self.adata_obs = adata.obs.reset_index(drop=True)
+
+        self.trainer = pl.Trainer(
+            min_epochs=1,
+            max_epochs=max_epochs,
+            check_val_every_n_epoch=val_epochs_check,
+            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience)],
+            devices="auto",
+            accelerator="auto",
+        )
+
+        self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
+
+        self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
+        self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)
+
+        # Obtain cell embeddings
+        with torch.no_grad():
+            self.mlp.eval()
+            for dataset_count, batch in enumerate(self.entire_dataset):
+                emb, y = self.mlp.get_embeddings(batch)
+                emb = torch.squeeze(emb)
+                batch_adata = AnnData(X=emb.cpu().numpy())
+                batch_adata.obs["perturbations"] = y
+                if dataset_count == 0:
+                    pert_adata = batch_adata
+                else:
+                    pert_adata = anndata.concat([pert_adata, batch_adata])
+
+        # Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
+        # and we can just add the annotations from the original AnnData object
+        pert_adata.obs = pert_adata.obs.reset_index(drop=True)
+        if "perturbations" in self.adata_obs.columns:
+            self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
+        pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)
+
+        # Drop the 'encoded_perturbations' column, since it stores the one-hot encoded labels as numpy arrays,
+        # which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
+        pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
+
+        return pert_adata
+
+    def load(self, adata, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The load method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def train(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The train method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+    def get_embeddings(self, **kwargs):
+        """This method is deprecated and will be removed in the future. Please use the compute method instead."""
+        raise DeprecationWarning(
+            "The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
+        )
+
+
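Because compute() returns one embedding per cell, the class docstring above suggests a second aggregation step; a sketch of that step (the PseudobulkSpace call and its mode argument are assumptions, not part of this diff):

    import pertpy as pt

    adata = pt.dt.norman_2019()
    dcs = pt.tl.MLPClassifierSpace()
    cell_embeddings = dcs.compute(adata, target_col="perturbation_name", max_epochs=5)

    # Aggregate per-cell embeddings into one embedding per perturbation (assumed API)
    ps = pt.tl.PseudobulkSpace()
    pert_embeddings = ps.compute(cell_embeddings, target_col="perturbations", mode="mean")
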
+class MLP(torch.nn.Module):
+    """
+    A multilayer perceptron with ReLU activations, optional Dropout and optional BatchNorm.
+    """
+
+    def __init__(
+        self,
+        sizes: list[int],
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+    ) -> None:
+        """
+        Args:
+            sizes: size of layers.
+            dropout: Dropout probability.
+            batch_norm: specifies if batch norm should be applied.
+            layer_norm: specifies if layer norm should be applied, as commonly used in Transformers.
+            last_layer_act: activation function of last layer.
+        """
+        super().__init__()
+        layers = []
+        for s in range(len(sizes) - 1):
+            layers += [
+                torch.nn.Linear(sizes[s], sizes[s + 1]),
+                torch.nn.BatchNorm1d(sizes[s + 1]) if batch_norm and s < len(sizes) - 2 else None,
+                torch.nn.LayerNorm(sizes[s + 1]) if layer_norm and s < len(sizes) - 2 and not batch_norm else None,
+                torch.nn.ReLU(),
+                torch.nn.Dropout(dropout) if s < len(sizes) - 2 else None,
+            ]
+
+        layers = [layer for layer in layers if layer is not None][:-1]
+        self.activation = last_layer_act
+        if self.activation == "linear":
+            pass
+        elif self.activation == "ReLU":
+            self.relu = torch.nn.ReLU()
+        else:
+            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
+
+        self.network = torch.nn.Sequential(*layers)
+
+        self.network.apply(init_weights)
+
+        self.sizes = sizes
+        self.batch_norm = batch_norm
+        self.layer_norm = layer_norm
+        self.last_layer_act = last_layer_act
+
+    def forward(self, x) -> torch.Tensor:
+        if self.activation == "ReLU":
+            return self.relu(self.network(x))
+        return self.network(x)
+
+    def embedding(self, x) -> torch.Tensor:
+        for layer in self.network[:-1]:
+            x = layer(x)
+        return x
+
+
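A minimal sketch of what the layer construction above produces (sizes invented for illustration): the [:-1] slice trims the trailing ReLU so the output stays linear, and embedding() stops before the final Linear layer:

    import torch

    from pertpy.tools._perturbation_space._discriminator_classifiers import MLP

    net = MLP(sizes=[50, 512, 10], dropout=0.1, batch_norm=True)
    # network: Linear(50, 512) -> BatchNorm1d(512) -> ReLU -> Dropout(0.1) -> Linear(512, 10)

    x = torch.randn(8, 50)
    logits = net(x)  # shape (8, 10); last_layer_act="linear" leaves these untouched
    features = net.embedding(x)  # shape (8, 512), penultimate-layer output
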
+def init_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_uniform_(m.weight)
+        m.bias.data.fill_(0.01)
+
+
+class PLDataset(Dataset):
+    """
+    Dataset for perturbation classification.
+    Needed to train a model that classifies perturbed cells and uses the second-to-last layer as the perturbation embedding.
+    """
+
+    def __init__(
+        self,
+        adata: np.array,
+        target_col: str = "perturbations",
+        label_col: str = "perturbations",
+        layer_key: str = None,
+    ):
+        """
+        Args:
+            adata: AnnData object with observations and labels.
+            target_col: key with the perturbation labels numerically encoded.
+            label_col: key with the perturbation labels.
+            layer_key: key of the layer to be used as data, otherwise .X
+        """
+
+        if layer_key:
+            self.data = adata.layers[layer_key]
+        else:
+            self.data = adata.X
+
+        self.labels = adata.obs[target_col]
+        self.pert_labels = adata.obs[label_col]
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        """Returns a sample and corresponding perturbations applied (labels)."""
+        sample = self.data[idx].A.squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
+        num_label = self.labels.iloc[idx]
+        str_label = self.pert_labels.iloc[idx]
+
+        return sample, num_label, str_label
+
+
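A toy sketch (AnnData and labels invented) of what PLDataset yields per index, mirroring the encoded_perturbations column that MLPClassifierSpace.compute() creates:

    import numpy as np
    from anndata import AnnData

    from pertpy.tools._perturbation_space._discriminator_classifiers import PLDataset

    adata = AnnData(X=np.random.rand(4, 3).astype(np.float32))
    adata.obs["perturbations"] = ["pert_a", "pert_b", "pert_a", "pert_b"]
    adata.obs["encoded_perturbations"] = [np.float32([1, 0]), np.float32([0, 1]), np.float32([1, 0]), np.float32([0, 1])]

    ds = PLDataset(adata, target_col="encoded_perturbations", label_col="perturbations")
    sample, one_hot, name = ds[0]  # dense row of X, one-hot float32 label, perturbation name
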
+class PerturbationClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        layers: list = [512],  # noqa
+        dropout: float = 0.0,
+        batch_norm: bool = True,
+        layer_norm: bool = False,
+        last_layer_act: str = "linear",
+        lr=1e-4,
+        seed=42,
+    ):
+        """
+        Args:
+            model: model to be trained
+            batch_size: batch size
+            layers: list of layers of the MLP
+            dropout: dropout probability
+            batch_norm: whether to apply batch norm
+            layer_norm: whether to apply layer norm
+            last_layer_act: activation function of last layer
+            lr: learning rate
+            seed: random seed
+        """
+        super().__init__()
+        self.batch_size = batch_size
+        self.save_hyperparameters()
+        if model:
+            self.net = model
+        else:
+            self._create_model()
+
+    def _create_model(self):
+        self.net = MLP(
+            sizes=self.hparams.layers,
+            dropout=self.hparams.dropout,
+            batch_norm=self.hparams.batch_norm,
+            layer_norm=self.hparams.layer_norm,
+            last_layer_act=self.hparams.last_layer_act,
+        )
+
+    def forward(self, x):
+        x = self.net(x)
+        return x
+
+    def configure_optimizers(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.1)
+
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, y, _ = batch
+        x = x.to(torch.float32)
+
+        y_hat = self.forward(x)
+
+        y = torch.argmax(y, dim=1)
+        y_hat = y_hat.squeeze()
+
+        loss = torch.nn.functional.cross_entropy(y_hat, y)
+        self.log("test_loss", loss, prog_bar=True, batch_size=self.batch_size)
+
+        return loss
+
+    def embedding(self, x):
+        """
+        Inputs:
+            x: Input features of shape [batch_size, n_features]
+        """
+        x = self.net.embedding(x)
+        return x
+
+    def get_embeddings(self, batch):
+        x, _, y = batch
+        x = x.to(torch.float32)
+
+        embedding = self.embedding(x)
+        return embedding, y
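
Finally, a toy sketch (batch invented) of pulling penultimate-layer embeddings through PerturbationClassifier.get_embeddings, the same path compute() takes over the entire_dataset DataLoader:

    import torch

    from pertpy.tools._perturbation_space._discriminator_classifiers import MLP, PerturbationClassifier

    net = MLP(sizes=[50, 512, 10])
    clf = PerturbationClassifier(model=net, batch_size=4)

    # A batch mimics PLDataset output: (features, one-hot labels, string labels)
    batch = (torch.randn(4, 50), torch.zeros(4, 10), ["pert_a", "pert_b", "pert_a", "ctrl"])
    with torch.no_grad():
        emb, names = clf.get_embeddings(batch)
    print(emb.shape)  # torch.Size([4, 512]); one embedding per cell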