drevalpy 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. drevalpy/__init__.py +0 -0
  2. drevalpy/datasets/__init__.py +6 -0
  3. drevalpy/datasets/ccle.py +27 -0
  4. drevalpy/datasets/dataset.py +464 -0
  5. drevalpy/datasets/gdsc1.py +27 -0
  6. drevalpy/datasets/gdsc2.py +10 -0
  7. drevalpy/evaluation.py +77 -0
  8. drevalpy/experiment.py +747 -0
  9. drevalpy/models/Baselines/__init__.py +0 -0
  10. drevalpy/models/Baselines/elastic_net_model.py +74 -0
  11. drevalpy/models/Baselines/hyperparameters.yaml +57 -0
  12. drevalpy/models/Baselines/naive_pred.py +180 -0
  13. drevalpy/models/Baselines/random_forest.py +79 -0
  14. drevalpy/models/Baselines/singledrug_random_forest.py +28 -0
  15. drevalpy/models/Baselines/svm.py +72 -0
  16. drevalpy/models/DrugRegNet/DrugRegNetModel.py +109 -0
  17. drevalpy/models/DrugRegNet/__init__.py +0 -0
  18. drevalpy/models/MOLI/__init__.py +0 -0
  19. drevalpy/models/MOLI/moli_model.py +65 -0
  20. drevalpy/models/SimpleNeuralNetwork/__init__.py +0 -0
  21. drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml +14 -0
  22. drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +129 -0
  23. drevalpy/models/SimpleNeuralNetwork/utils.py +149 -0
  24. drevalpy/models/__init__.py +31 -0
  25. drevalpy/models/drp_model.py +293 -0
  26. drevalpy/models/utils.py +42 -0
  27. drevalpy/moli_wrapper.py +110 -0
  28. drevalpy/utils.py +289 -0
  29. drevalpy/visualization/__init__.py +0 -0
  30. drevalpy/visualization/corr_comp_scatter.py +203 -0
  31. drevalpy/visualization/create_report.py +387 -0
  32. drevalpy/visualization/heatmap.py +142 -0
  33. drevalpy/visualization/regression_slider_plot.py +116 -0
  34. drevalpy/visualization/scatter_eval_models.py +154 -0
  35. drevalpy/visualization/single_model_regression.py +110 -0
  36. drevalpy/visualization/style_utils/LCO.png +0 -0
  37. drevalpy/visualization/style_utils/LDO.png +0 -0
  38. drevalpy/visualization/style_utils/LPO.png +0 -0
  39. drevalpy/visualization/style_utils/favicon.png +0 -0
  40. drevalpy/visualization/style_utils/index_layout.html +90 -0
  41. drevalpy/visualization/style_utils/nf-core-drugresponseeval_logo_light.png +0 -0
  42. drevalpy/visualization/style_utils/page_layout.html +104 -0
  43. drevalpy/visualization/utils.py +312 -0
  44. drevalpy/visualization/violin.py +293 -0
  45. drevalpy-1.0.0.dist-info/LICENSE +674 -0
  46. drevalpy-1.0.0.dist-info/METADATA +20 -0
  47. drevalpy-1.0.0.dist-info/RECORD +49 -0
  48. drevalpy-1.0.0.dist-info/WHEEL +5 -0
  49. drevalpy-1.0.0.dist-info/top_level.txt +1 -0
drevalpy/__init__.py ADDED
File without changes
@@ -0,0 +1,6 @@
1
# Names exported by the datasets package.
__all__ = ["GDSC1", "GDSC2", "CCLE", "RESPONSE_DATASET_FACTORY"]

from .gdsc1 import GDSC1
from .gdsc2 import GDSC2
from .ccle import CCLE

# Registry mapping a response dataset name to the class that loads it.
RESPONSE_DATASET_FACTORY = {
    "GDSC1": GDSC1,
    "GDSC2": GDSC2,
    "CCLE": CCLE,
}
@@ -0,0 +1,27 @@
1
+ from drevalpy.datasets.dataset import DrugResponseDataset
2
+ import pandas as pd
3
+ import os
4
+
5
+
6
class CCLE(DrugResponseDataset):
    """
    CCLE drug response dataset, read from a CSV file.
    """

    def __init__(
        self,
        path_data: str = "data",
        file_name: str = "response_CCLE.csv",
        dataset_name: str = "CCLE",
    ):
        """
        Reads ``<path_data>/<dataset_name>/<file_name>`` and initializes the dataset from it.

        The CSV must provide the columns ``LN_IC50`` (used as response),
        ``CELL_LINE_NAME`` (cell line IDs) and ``DRUG_NAME`` (drug IDs).

        :param path_data: root directory that contains one sub-directory per dataset
        :param file_name: name of the response CSV file inside the dataset directory
        :param dataset_name: name of the dataset; also the name of its sub-directory
        """
        csv_path = os.path.join(path_data, dataset_name, file_name)
        table = pd.read_csv(csv_path)
        ln_ic50 = table["LN_IC50"].values
        cell_lines = table["CELL_LINE_NAME"].values
        drugs = table["DRUG_NAME"].values
        super().__init__(
            response=ln_ic50,
            cell_line_ids=cell_lines,
            drug_ids=drugs,
            dataset_name=dataset_name,
        )
@@ -0,0 +1,464 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+ import numpy as np
4
+ from numpy.typing import ArrayLike
5
+ import pandas as pd
6
+ from ..utils import leave_pair_out_cv, leave_group_out_cv
7
+ import copy
8
+ from sklearn.base import TransformerMixin
9
+
10
+
11
class Dataset(ABC):
    """
    Abstract base class for datasets.

    Concrete subclasses have to implement both :meth:`load` and :meth:`save`;
    the base class itself cannot be instantiated.
    """

    @abstractmethod
    def load(self):
        """
        Loads the dataset from data.
        """

    @abstractmethod
    def save(self):
        """
        Saves the dataset to data.
        """
29
+
30
+
31
class DrugResponseDataset(Dataset):
    """
    Drug response dataset.

    Stores one response value per (cell line, drug) row in the parallel arrays
    ``response``, ``cell_line_ids`` and ``drug_ids``. ``predictions`` optionally
    holds one predicted value per row. The filtering, shuffling and
    transformation methods modify the dataset in place and keep all arrays
    (including ``predictions``, when present) aligned.
    """

    def __init__(
        self,
        response: Optional[ArrayLike] = None,
        cell_line_ids: Optional[ArrayLike] = None,
        drug_ids: Optional[ArrayLike] = None,
        predictions: Optional[ArrayLike] = None,
        dataset_name: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the drug response dataset.

        :param response: drug response values per cell line and drug
        :param cell_line_ids: cell line IDs
        :param drug_ids: drug IDs
        :param predictions: optional. Predicted drug response values per cell line and drug
        :param dataset_name: optional. Name of the dataset
        :raises AssertionError: if the supplied arrays differ in length

        Variables:
            response: drug response values per cell line and drug
            cell_line_ids: cell line IDs
            drug_ids: drug IDs
            predictions: optional. Predicted drug response values per cell line and drug
            dataset_name: optional. Name of the dataset
        """
        super(DrugResponseDataset, self).__init__()
        if response is not None:
            self.response = np.array(response)
            self.cell_line_ids = np.array(cell_line_ids)
            self.drug_ids = np.array(drug_ids)
            assert len(self.response) == len(
                self.cell_line_ids
            ), "response and cell_line_ids have different lengths"
            assert len(self.response) == len(
                self.drug_ids
            ), "response and drug_ids have different lengths"
        else:
            # Without a response the other arrays are stored exactly as passed.
            self.response = response
            self.cell_line_ids = cell_line_ids
            self.drug_ids = drug_ids
        self.dataset_name = dataset_name

        if predictions is not None:
            self.predictions = np.array(predictions)
            assert len(self.predictions) == len(
                self.response
            ), "predictions and response have different lengths"
        else:
            self.predictions = None

    def __len__(self):
        """Returns the number of rows (response values)."""
        return len(self.response)

    def __str__(self):
        """Returns a short summary showing at most the first three rows."""
        if len(self.response) > 3:
            string = f"DrugResponseDataset: CLs {self.cell_line_ids[:3]}...; Drugs {self.drug_ids[:3]}...; Response {self.response[:3]}..."
        else:
            string = f"DrugResponseDataset: CLs {self.cell_line_ids}; Drugs {self.drug_ids}; Response {self.response}"
        if self.predictions is not None:
            if len(self.predictions) > 3:
                string += f"; Predictions {self.predictions[:3]}..."
            else:
                string += f"; Predictions {self.predictions}"
        return string

    def load(self, path: str):
        """
        Loads the drug response dataset from a CSV file.

        The file must contain the columns ``response``, ``cell_line_ids`` and
        ``drug_ids``; a ``predictions`` column is optional.

        :param path: path of the CSV file to read
        """
        data = pd.read_csv(path)
        self.response = data["response"].values
        self.cell_line_ids = data["cell_line_ids"].values
        self.drug_ids = data["drug_ids"].values
        if "predictions" in data.columns:
            self.predictions = data["predictions"].values
        else:
            # Predictions held before loading would not match the new rows.
            self.predictions = None

    def save(self, path: str):
        """
        Saves the drug response dataset to a CSV file.

        Writes the columns ``cell_line_ids``, ``drug_ids``, ``response`` and,
        if predictions are present, ``predictions``.

        :param path: path of the CSV file to write
        """
        out = pd.DataFrame(
            {
                "cell_line_ids": self.cell_line_ids,
                "drug_ids": self.drug_ids,
                "response": self.response,
            }
        )
        if self.predictions is not None:
            out["predictions"] = self.predictions
        out.to_csv(path, index=False)

    def add_rows(self, other: "DrugResponseDataset") -> None:
        """
        Adds rows from another dataset.

        Predictions are only concatenated when both datasets have them.

        :param other: other dataset
        """
        self.response = np.concatenate([self.response, other.response])
        self.cell_line_ids = np.concatenate([self.cell_line_ids, other.cell_line_ids])
        self.drug_ids = np.concatenate([self.drug_ids, other.drug_ids])

        # NOTE(review): if only one of the two datasets has predictions, the
        # existing predictions are left untouched and no longer match the rows.
        if self.predictions is not None and other.predictions is not None:
            self.predictions = np.concatenate([self.predictions, other.predictions])

    def remove_nan_responses(self) -> None:
        """
        Removes rows with NaN values in the response.
        """
        mask = np.isnan(self.response)
        self.response = self.response[~mask]
        self.cell_line_ids = self.cell_line_ids[~mask]
        self.drug_ids = self.drug_ids[~mask]
        if self.predictions is not None:
            self.predictions = self.predictions[~mask]

    def shuffle(self, random_state: int = 42) -> None:
        """
        Shuffles the rows of the dataset.

        Note that this seeds NumPy's global random number generator.

        :param random_state: seed for the random number generator
        """
        indices = np.arange(len(self.response))
        np.random.seed(random_state)
        np.random.shuffle(indices)
        self.response = self.response[indices]
        self.cell_line_ids = self.cell_line_ids[indices]
        self.drug_ids = self.drug_ids[indices]
        if self.predictions is not None:
            self.predictions = self.predictions[indices]

    def remove_drugs(self, drugs_to_remove: Union[str, list]) -> None:
        """
        Removes all rows belonging to the given drugs.

        :param drugs_to_remove: name of drug or list of names of multiple drugs to remove
        """
        if isinstance(drugs_to_remove, str):
            drugs_to_remove = [drugs_to_remove]

        # A set gives constant-time membership tests per row.
        unwanted = set(drugs_to_remove)
        keep = np.array([drug not in unwanted for drug in self.drug_ids], dtype=bool)
        # mask() also filters the predictions, keeping all arrays aligned.
        self.mask(keep)

    def remove_cell_lines(self, cell_lines_to_remove: Union[str, list]) -> None:
        """
        Removes all rows belonging to the given cell lines.

        :param cell_lines_to_remove: name of cell line or list of names of multiple cell lines to remove
        """
        if isinstance(cell_lines_to_remove, str):
            cell_lines_to_remove = [cell_lines_to_remove]

        # A set gives constant-time membership tests per row.
        unwanted = set(cell_lines_to_remove)
        keep = np.array(
            [cell_line not in unwanted for cell_line in self.cell_line_ids],
            dtype=bool,
        )
        # mask() also filters the predictions, keeping all arrays aligned.
        self.mask(keep)

    def remove_rows(self, indices: ArrayLike) -> None:
        """
        Removes rows from the dataset.

        :param indices: indices of rows to remove
        """
        self.drug_ids = np.delete(self.drug_ids, indices)
        self.cell_line_ids = np.delete(self.cell_line_ids, indices)
        self.response = np.delete(self.response, indices)
        if self.predictions is not None:
            self.predictions = np.delete(self.predictions, indices)

    def reduce_to(
        self, cell_line_ids: Optional[ArrayLike], drug_ids: Optional[ArrayLike]
    ) -> None:
        """
        Removes all rows which contain a cell_line not in cell_line_ids or a drug not in drug_ids.

        :param cell_line_ids: cell line IDs or None to keep all cell lines
        :param drug_ids: drug IDs or None to keep all drugs
        """
        if drug_ids is not None:
            self.remove_drugs(list(set(self.drug_ids) - set(drug_ids)))

        if cell_line_ids is not None:
            self.remove_cell_lines(list(set(self.cell_line_ids) - set(cell_line_ids)))

    def split_dataset(
        self,
        n_cv_splits,
        mode,
        split_validation=True,
        split_early_stopping=True,
        validation_ratio=0.1,
        random_state=42,
    ) -> List[dict]:
        """
        Splits the dataset into training, validation and test sets for crossvalidation.

        The splits are also stored in ``self.cv_splits``.

        :param n_cv_splits: number of cross-validation splits
        :param mode: split mode (LPO=Leave-random-Pairs-Out, LCO=Leave-Cell-line-Out, LDO=Leave-Drug-Out)
        :param split_validation: forwarded to the split helper; must be True for the early stopping split to be made
        :param split_early_stopping: if True (and split_validation is True), each split's validation set
            is further divided into "validation_es" and "early_stopping"
        :param validation_ratio: forwarded to the split helper
        :param random_state: forwarded to the split helper
        :return: list with one dictionary of datasets per split
        :raises ValueError: if mode is not one of 'LPO', 'LCO', 'LDO'
        """

        cell_line_ids = self.cell_line_ids
        drug_ids = self.drug_ids
        response = self.response

        if mode == "LPO":
            cv_splits = leave_pair_out_cv(
                n_cv_splits,
                response,
                cell_line_ids,
                drug_ids,
                split_validation,
                validation_ratio,
                random_state,
                self.dataset_name,
            )

        elif mode in ["LCO", "LDO"]:
            group = "cell_line" if mode == "LCO" else "drug"
            cv_splits = leave_group_out_cv(
                group=group,
                n_cv_splits=n_cv_splits,
                response=response,
                cell_line_ids=cell_line_ids,
                drug_ids=drug_ids,
                split_validation=split_validation,
                validation_ratio=validation_ratio,
                random_state=random_state,
                dataset_name=self.dataset_name,
            )
        else:
            raise ValueError(
                f"Unknown split mode '{mode}'. Choose from 'LPO', 'LCO', 'LDO'."
            )

        if split_validation and split_early_stopping:
            for split in cv_splits:
                validation_es, early_stopping = split_early_stopping_data(
                    split["validation"], test_mode=mode
                )
                split["validation_es"] = validation_es
                split["early_stopping"] = early_stopping
        self.cv_splits = cv_splits
        return cv_splits

    def copy(self):
        """
        Returns a copy of the drug response dataset.

        The arrays are deep-copied; the result is always a plain DrugResponseDataset.
        """
        return DrugResponseDataset(
            response=copy.deepcopy(self.response),
            cell_line_ids=copy.deepcopy(self.cell_line_ids),
            drug_ids=copy.deepcopy(self.drug_ids),
            predictions=copy.deepcopy(self.predictions),
            dataset_name=self.dataset_name,
        )

    def __hash__(self):
        """Hashes the dataset name together with the content of all arrays."""
        return hash(
            (
                self.dataset_name,
                tuple(self.cell_line_ids),
                tuple(self.drug_ids),
                tuple(self.response),
                tuple(self.predictions) if self.predictions is not None else None,
            )
        )

    def mask(self, mask: List[bool]) -> None:
        """
        Keeps only the rows selected by the mask.

        :param mask: boolean mask with one entry per row; True keeps the row
        """
        self.response = self.response[mask]
        self.cell_line_ids = self.cell_line_ids[mask]
        self.drug_ids = self.drug_ids[mask]
        if self.predictions is not None:
            self.predictions = self.predictions[mask]

    @staticmethod
    def _apply_to_column(func, values: np.ndarray) -> np.ndarray:
        """Applies func to values shaped as a single column and returns a 1-D array."""
        # reshape(-1) keeps a one-row dataset 1-D, whereas squeeze() would
        # collapse it to a 0-d array and break len().
        return np.asarray(func(values.reshape(-1, 1))).reshape(-1)

    def transform(self, response_transformation: TransformerMixin) -> None:
        """
        Apply transformation to the response data and prediction data of the dataset.

        :param response_transformation: fitted transformer providing ``transform``
        """
        self.response = self._apply_to_column(
            response_transformation.transform, self.response
        )
        if self.predictions is not None:
            self.predictions = self._apply_to_column(
                response_transformation.transform, self.predictions
            )

    def fit_transform(self, response_transformation: TransformerMixin) -> None:
        """
        Fit and transform the response data and prediction data of the dataset.

        The transformer is fitted on the response only and then applied to
        response and predictions.

        :param response_transformation: transformer providing ``fit`` and ``transform``
        """
        # fit() returns the transformer itself, so there is nothing to squeeze.
        response_transformation.fit(self.response.reshape(-1, 1))
        self.transform(response_transformation)

    def inverse_transform(self, response_transformation: TransformerMixin) -> None:
        """
        Inverse transform the response data and prediction data of the dataset.

        :param response_transformation: fitted transformer providing ``inverse_transform``
        """
        self.response = self._apply_to_column(
            response_transformation.inverse_transform, self.response
        )
        if self.predictions is not None:
            self.predictions = self._apply_to_column(
                response_transformation.inverse_transform, self.predictions
            )
335
+
336
+
337
def split_early_stopping_data(
    validation_dataset: DrugResponseDataset, test_mode: str
) -> Tuple[DrugResponseDataset, DrugResponseDataset]:
    """
    Divides a validation dataset into a validation part and an early stopping part.

    The dataset is shuffled in place (seed 42) and split into 4 folds using
    ``test_mode``; the first fold's "train" part becomes the validation set and
    its "test" part the early stopping set (i.e. 3/4 vs. 1/4 of a 4-fold split).

    :param validation_dataset: dataset to split; it is shuffled in place
    :param test_mode: split mode passed on to ``split_dataset``
    :return: tuple of (validation dataset, early stopping dataset)
    """
    validation_dataset.shuffle(random_state=42)
    first_fold = validation_dataset.split_dataset(
        n_cv_splits=4,
        mode=test_mode,
        split_validation=False,
        split_early_stopping=False,
        random_state=42,
    )[0]
    return first_fold["train"], first_fold["test"]
353
+
354
+
355
class FeatureDataset(Dataset):
    """
    Class for feature datasets.

    ``features`` maps an identifier (cell line or drug) to a dictionary of
    feature views, which in turn maps a view name to a feature vector.
    """

    def __init__(self, features: Dict[str, Dict[str, np.ndarray]], *args, **kwargs):
        """
        Initializes the feature dataset.

        :param features: dictionary of features, key: identifier (cell line or drug),
            value: Dict of feature views, key: feature name, value: feature vector
        """
        super(FeatureDataset, self).__init__()
        self.features = features
        self.view_names = self.get_view_names()
        self.identifiers = self.get_ids()

    def save(self):
        """
        Saves the feature dataset to data.

        :raises NotImplementedError: always; saving is not implemented
        """
        raise NotImplementedError("save method not implemented")

    def load(self):
        """
        Loads the feature dataset from data.

        :raises NotImplementedError: always; loading is not implemented
        """
        raise NotImplementedError("load method not implemented")

    def randomize_features(
        self, views_to_randomize: Union[str, list], randomization_type: str
    ) -> None:
        """
        Randomizes the feature vectors.

        :param views_to_randomize: name of feature view or list of names of multiple
            feature views to randomize. The other views are not randomized.
        :param randomization_type: randomization type (permutation, gaussian, zeroing)
        :raises ValueError: if randomization_type is unknown
        """
        if isinstance(views_to_randomize, str):
            views_to_randomize = [views_to_randomize]

        if randomization_type == "permutation":
            # Get the entity names
            identifiers = self.get_ids()

            # Permute the specified views for each entity (= cell line or drug):
            # every entity receives the selected views of a randomly drawn
            # other entity and keeps its own remaining views.
            self.features = {
                entity: {
                    view: (
                        self.features[entity][view]
                        if view not in views_to_randomize
                        else self.features[other_entity][view]
                    )
                    for view in self.view_names
                }
                for entity, other_entity in zip(
                    identifiers, np.random.permutation(identifiers)
                )
            }

        elif randomization_type == "gaussian":
            # Replace each vector by normal noise with the vector's own mean and std.
            for view in views_to_randomize:
                for identifier in self.get_ids():
                    self.features[identifier][view] = np.random.normal(
                        self.features[identifier][view].mean(),
                        self.features[identifier][view].std(),
                        self.features[identifier][view].shape,
                    )
        elif randomization_type == "zeroing":
            for view in views_to_randomize:
                for identifier in self.get_ids():
                    self.features[identifier][view] = np.zeros(
                        self.features[identifier][view].shape
                    )
        else:
            raise ValueError(
                f"Unknown randomization mode '{randomization_type}'. Choose from 'permutation', 'gaussian', 'zeroing'."
            )

    def get_ids(self):
        """
        Returns the identifiers (keys of the feature dictionary) of the dataset.
        """
        return list(self.features.keys())

    def get_view_names(self):
        """
        Returns the feature view names, taken from the first entry of the dataset.

        An empty feature dictionary yields an empty list.
        """
        if not self.features:
            return []
        first_views = next(iter(self.features.values()))
        return list(first_views.keys())

    def get_feature_matrix(self, view: str, identifiers: List[str]) -> np.ndarray:
        """
        Returns the feature matrix for the given view.

        :param view: view name
        :param identifiers: list of identifiers (cell lines or drugs)
        :return: feature matrix with one row per entry of identifiers, in the same order
        :raises AssertionError: if the view or any identifier is unknown
        """
        assert view in self.view_names, f"View '{view}' not in in the FeatureDataset."
        # A set makes each membership test constant-time instead of a list scan.
        known_identifiers = set(self.identifiers)
        missing_identifiers = {
            id_ for id_ in identifiers if id_ not in known_identifiers
        }
        assert (
            not missing_identifiers
        ), f"{len(missing_identifiers)} of {len(np.unique(identifiers))} ids are not in the FeatureDataset. Missing ids: {missing_identifiers}"

        return np.stack([self.features[id_][view] for id_ in identifiers], axis=0)

    def copy(self):
        """
        Returns a copy of the feature dataset (the features are deep-copied).
        """
        return FeatureDataset(features=copy.deepcopy(self.features))
@@ -0,0 +1,27 @@
1
+ from drevalpy.datasets.dataset import DrugResponseDataset
2
+ import pandas as pd
3
+ import os
4
+
5
+
6
class GDSC1(DrugResponseDataset):
    """
    GDSC1 drug response dataset, read from a CSV file.
    """

    def __init__(
        self,
        path_data: str = "data",
        file_name: str = "response_GDSC1.csv",
        dataset_name: str = "GDSC1",
    ):
        """
        Reads ``<path_data>/<dataset_name>/<file_name>`` and initializes the dataset from it.

        The CSV must provide the columns ``LN_IC50`` (used as response),
        ``CELL_LINE_NAME`` (cell line IDs) and ``DRUG_NAME`` (drug IDs).

        :param path_data: root directory that contains one sub-directory per dataset
        :param file_name: name of the response CSV file inside the dataset directory
        :param dataset_name: name of the dataset; also the name of its sub-directory
        """
        csv_path = os.path.join(path_data, dataset_name, file_name)
        table = pd.read_csv(csv_path)
        ln_ic50 = table["LN_IC50"].values
        cell_lines = table["CELL_LINE_NAME"].values
        drugs = table["DRUG_NAME"].values
        super().__init__(
            response=ln_ic50,
            cell_line_ids=cell_lines,
            drug_ids=drugs,
            dataset_name=dataset_name,
        )
@@ -0,0 +1,10 @@
1
+ from drevalpy.datasets.gdsc1 import GDSC1
2
+
3
+
4
class GDSC2(GDSC1):
    """
    GDSC2 drug response dataset.

    Reuses the GDSC1 loader with the dataset name fixed to "GDSC2".
    """

    def __init__(self, path_data: str = "data/", file_name: str = "response_GDSC2.csv"):
        """
        :param path_data: root directory that contains one sub-directory per dataset
        :param file_name: name of the response CSV file inside the "GDSC2" directory
        """
        super().__init__(
            path_data=path_data,
            file_name=file_name,
            dataset_name="GDSC2",
        )
drevalpy/evaluation.py ADDED
@@ -0,0 +1,77 @@
1
+ from .datasets.dataset import DrugResponseDataset
2
+ from typing import Union, List
3
+ import sklearn.metrics as metrics
4
+ from .utils import pearson, spearman, kendall, partial_correlation
5
+ import pandas as pd
6
+ import numpy as np
7
+
8
# Maps each supported metric name to the function that computes it.
# The functions are called with the keyword arguments y_pred and y_true.
AVAILABLE_METRICS = {
    # error metrics from scikit-learn
    "MSE": metrics.mean_squared_error,
    "RMSE": metrics.root_mean_squared_error,
    "MAE": metrics.mean_absolute_error,
    "R^2": metrics.r2_score,
    # correlation metrics from drevalpy's own utils module
    "Pearson": pearson,
    "Spearman": spearman,
    "Kendall": kendall,
    "Partial_Correlation": partial_correlation,
}
18
# Metrics for which a lower value is better.
MINIMIZATION_METRICS = [
    "MSE",
    "RMSE",
    "MAE",
]
# Metrics for which a higher value is better.
MAXIMIZATION_METRICS = [
    "R^2",
    "Pearson",
    "Spearman",
    "Kendall",
    "Partial_Correlation",
]
20
+
21
+
22
def get_mode(metric: str):
    """
    Returns the optimization direction of a metric.

    :param metric: name of the metric
    :return: "min" if the metric is in MINIMIZATION_METRICS, "max" if it is in MAXIMIZATION_METRICS
    :raises ValueError: if the metric is in neither list
    """
    if metric in MINIMIZATION_METRICS:
        return "min"
    if metric in MAXIMIZATION_METRICS:
        return "max"
    raise ValueError(
        f"Invalid metric: {metric}. Need to add metric to MINIMIZATION_METRICS or MAXIMIZATION_METRICS?"
    )
32
+
33
+
34
def evaluate(dataset: DrugResponseDataset, metric: Union[List[str], str]):
    """
    Computes evaluation metrics from the predictions and responses of a dataset.

    :param dataset: dataset whose ``predictions`` are compared against its ``response``
    :param metric: one metric name or a list of metric names; each must be a key of
        AVAILABLE_METRICS ("MSE", "RMSE", "MAE", "R^2", "Pearson", "Spearman",
        "Kendall", "Partial_Correlation")
    :return: dictionary mapping each requested metric name to its value as float;
        NaN for every metric if the dataset has fewer than two responses
    :raises AssertionError: if a requested metric is not in AVAILABLE_METRICS
    """
    metric_names = [metric] if isinstance(metric, str) else metric
    y_pred = dataset.predictions
    y_true = dataset.response
    scores = {}
    for name in metric_names:
        assert (
            name in AVAILABLE_METRICS
        ), f"invalid metric {name}. Available: {list(AVAILABLE_METRICS.keys())}"
        if len(y_true) < 2:
            # Too few data points: report NaN instead of calling the metric.
            scores[name] = float(np.nan)
            continue
        # Only the partial correlation additionally needs the row identifiers.
        extra_kwargs = {}
        if name == "Partial_Correlation":
            extra_kwargs = {
                "cell_line_ids": dataset.cell_line_ids,
                "drug_ids": dataset.drug_ids,
            }
        scores[name] = float(
            AVAILABLE_METRICS[name](y_pred=y_pred, y_true=y_true, **extra_kwargs)
        )

    return scores
68
+
69
+
70
def visualize_results(results: pd.DataFrame, mode: Union[List[str], str]):
    """
    Placeholder for visualizing evaluation results; not implemented yet.

    :param results: results to visualize (currently unused)
    :param mode: visualization mode or list of modes (currently unused)
    :raises NotImplementedError: always
    """
    message = "visualize not implemented yet"
    raise NotImplementedError(message)