vpop_calibration-2.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,420 @@
+ import math
+ import torch
+ import numpy as np
+ import pandas as pd
+ from functools import reduce
+
+ from ..utils import device
+
+
+ class TrainingDataSet:
+     def __init__(
+         self,
+         training_df: pd.DataFrame,
+         descriptors: list[str],
+         training_proportion: float = 0.7,
+         log_lower_limit: float = 1e-10,
+         log_inputs: list[str] | None = None,
+         log_outputs: list[str] | None = None,
+         data_already_normalized: bool = False,
+     ):
+         """Instantiate a TrainingDataSet container
+
+         The data container processes training data, normalizes inputs and outputs, and provides utility methods to transform inputs and outputs for a PyTorch model. In particular, it turns each combination of output name and protocol arm into an individual task (cartesian product), and it provides mappings between tasks and protocol/output.
+
+         Args:
+             training_df (pd.DataFrame): the training data. Should contain the columns [`id`, `output_name`, `protocol_arm`, *descriptors, `value`]
+             descriptors (list[str]): the names of the columns of `training_df` which correspond to descriptors on which to train the model
+             training_proportion (float, optional): proportion of patients to be used for training vs. validation. Defaults to 0.7.
+             log_lower_limit (float, optional): floor applied to values before log-scaling, to avoid numerical errors on zero (or negative) values
+             log_inputs (list[str], optional): the parameter inputs which should be rescaled to log when fed to the GP. Avoid adding time here, or any parameter that takes 0 as a value.
+             log_outputs (list[str], optional): the model outputs which should be rescaled to log
+             data_already_normalized (bool, optional): set to True if the data set is preprocessed and no normalization / scaling is required
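+
+         Example (illustrative sketch, assuming a single descriptor `dose` and
+         two protocol arms; the loading `print` is elided):
+
+             >>> df = pd.DataFrame({
+             ...     "id": [0, 0, 1, 1],
+             ...     "dose": [1.0, 1.0, 2.0, 2.0],
+             ...     "output_name": ["conc"] * 4,
+             ...     "protocol_arm": ["armA", "armB", "armA", "armB"],
+             ...     "value": [0.5, 0.7, 0.9, 1.1],
+             ... })
+             >>> tds = TrainingDataSet(df, ["dose"], training_proportion=1)
+             >>> tds.tasks
+             ['conc_armA', 'conc_armB']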
+         """
+
+         # Process the supplied data set
+         self.full_df_raw = training_df
+
+         declared_columns = self.full_df_raw.columns.to_list()
+         # Input validation
+         if "id" not in declared_columns:
+             raise ValueError("Training data should contain an `id` column.")
+         if "output_name" not in declared_columns:
+             raise ValueError("Training data should contain an `output_name` column.")
+         if "value" not in declared_columns:
+             raise ValueError("Training data should contain a `value` column.")
+         if not set(descriptors) <= set(declared_columns):
+             raise ValueError(
+                 f"The following descriptors are missing from the data set: {set(descriptors) - set(declared_columns)}."
+             )
+
+         self.parameter_names = descriptors
+         self.training_proportion = training_proportion
+         self.nb_parameters = len(self.parameter_names)
+         self.data_already_normalized = data_already_normalized
+         if "protocol_arm" not in declared_columns:
+             self.full_df_raw["protocol_arm"] = "identity"
+         self.protocol_arms = self.full_df_raw["protocol_arm"].unique().tolist()
+         self.nb_protocol_arms = len(self.protocol_arms)
+         self.output_names = self.full_df_raw["output_name"].unique().tolist()
+         self.nb_outputs = len(self.output_names)
+         self.log_lower_limit = log_lower_limit
+         self.log_inputs = log_inputs or []
+         self.log_inputs_indices = [
+             self.parameter_names.index(p) for p in self.log_inputs
+         ]
+         self.log_outputs = log_outputs or []
+
+         # Ensure the input df has a consistent shape (and remove potential extra columns)
+         self.full_df_raw = self.full_df_raw[
+             ["id"] + self.parameter_names + ["output_name", "protocol_arm", "value"]
+         ]
+
+         # Gather the list of patients in the training data
+         self.patients = self.full_df_raw["id"].unique()
+         self.nb_patients = self.patients.shape[0]
+
+         print(
+             f"Successfully loaded a training data set with {self.nb_patients} patients. The following outputs are available:\n{self.output_names}\nand the following protocol arms:\n{self.protocol_arms}"
+         )
+         # Construct the list of tasks, mapping from output name and protocol arm to task number
+         self.tasks: list[str] = [
+             output + "_" + protocol
+             for protocol in self.protocol_arms
+             for output in self.output_names
+         ]
+         self.nb_tasks = len(self.tasks)
+         # Map tasks to output names
+         self.task_to_output = {
+             output_name + "_" + protocol_arm: output_name
+             for output_name in self.output_names
+             for protocol_arm in self.protocol_arms
+         }
+         # Map task index to output index
+         self.task_idx_to_output_idx = {
+             self.tasks.index(k): self.output_names.index(v)
+             for k, v in self.task_to_output.items()
+         }
+         # Map task to protocol arm
+         self.task_to_protocol = {
+             output_name + "_" + protocol_arm: protocol_arm
+             for output_name in self.output_names
+             for protocol_arm in self.protocol_arms
+         }
+         # Map task index to protocol arm
+         self.task_idx_to_protocol = {
+             self.tasks.index(k): v for k, v in self.task_to_protocol.items()
+         }
+         # List the tasks that should be rescaled to log
+         self.log_tasks = [
+             task for task in self.tasks if self.task_to_output[task] in self.log_outputs
+         ]
+         self.log_tasks_indices = [self.tasks.index(task) for task in self.log_tasks]
+
+         ## Data processing
+         # Pivot the data to the correct shape for GP training
+         self.full_df_reshaped = self.pivot_input_data(self.full_df_raw)
+
+         # Normalize the inputs and the outputs (only if required)
+         if self.data_already_normalized:
+             self.normalized_df = self.full_df_reshaped
+         else:
+             # Log-scale the selected inputs and tasks, flooring values at log_lower_limit
+             self.full_df_reshaped[self.log_inputs + self.log_tasks] = (
+                 self.full_df_reshaped[self.log_inputs + self.log_tasks].apply(
+                     lambda val: np.log(np.maximum(val, self.log_lower_limit))
+                 )
+             )
+
+             self.normalized_df, mean, std = self.normalize_data(
+                 self.full_df_reshaped, ["id"]
+             )
+             self.normalizing_input_mean, self.normalizing_input_std = (
+                 torch.as_tensor(mean.loc[self.parameter_names].values, device=device),
+                 torch.as_tensor(std.loc[self.parameter_names].values, device=device),
+             )
+             self.normalizing_output_mean, self.normalizing_output_std = (
+                 torch.as_tensor(mean.loc[self.tasks].values, device=device),
+                 torch.as_tensor(std.loc[self.tasks].values, device=device),
+             )
+
+         self.unnormalize_output_wide = torch.compile(self.unnormalize_output_wide_logic)
+
+         # Compute the number of patients for training
+         self.nb_patients_training = math.floor(
+             self.training_proportion * self.nb_patients
+         )
+         self.nb_patients_validation = self.nb_patients - self.nb_patients_training
+
+         if self.training_proportion != 1:  # non-empty validation data set
+             if self.nb_patients_training == self.nb_patients:
+                 raise ValueError(
+                     "Training proportion too high for the number of sets of parameters: all would be used for training. Set training_proportion to 1 if that is your intention."
+                 )
+
+             # Randomly shuffle the patients
+             mixed_patients = np.random.permutation(self.patients)
+
+             self.training_patients = mixed_patients[: self.nb_patients_training]
+             self.validation_patients = mixed_patients[self.nb_patients_training :]
+
+             self.training_df_normalized: pd.DataFrame = self.normalized_df.loc[
+                 self.normalized_df["id"].isin(self.training_patients)
+             ]
+             self.validation_df_normalized: pd.DataFrame = self.normalized_df.loc[
+                 self.normalized_df["id"].isin(self.validation_patients)
+             ]
+             self.X_validation = torch.as_tensor(
+                 self.validation_df_normalized[self.parameter_names].values,
+                 device=device,
+             )
+             self.Y_validation = torch.as_tensor(
+                 self.validation_df_normalized[self.tasks].values, device=device
+             )
+
+         else:  # no validation data set
+             self.training_patients = self.patients  # all patients are used for training
+             self.validation_patients = np.array([])
+             self.training_df_normalized = self.normalized_df
+             self.validation_df_normalized = None
+             self.X_validation = None
+             self.Y_validation = None
+
+         self.X_training: torch.Tensor = torch.as_tensor(
+             self.training_df_normalized[self.parameter_names].values, device=device
+         )
+         self.Y_training: torch.Tensor = torch.as_tensor(
+             self.training_df_normalized[self.tasks].values, device=device
+         )
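+
+         # Resulting tensors: X_training has shape (nb_rows, nb_parameters) and
+         # Y_training has shape (nb_rows, nb_tasks), with one row per pivoted
+         # (id, descriptors) combination.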
+
+     def pivot_input_data(self, data_in: pd.DataFrame) -> pd.DataFrame:
+         """Pivot and reorder the columns of a data frame to feed to the model
+
+         This method is used at initialization on the training data frame, and when plotting the model performance against existing data.
+
+         Args:
+             data_in (pd.DataFrame): Input data frame, containing the following columns
+                 - `id`: patient id
+                 - one column per descriptor; the same descriptors as self.parameter_names should be present
+                 - `output_name`: the name of the output
+                 - `protocol_arm`: the name of the protocol arm
+                 - `value`: the observed value
+
+         Returns:
+             pd.DataFrame: A validated and pivoted dataframe with one column per task (`outputName_protocolArm`), and one row per observation
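+
+         Example (illustrative sketch, reusing the `tds` instance built in the
+         class docstring):
+
+             >>> wide = tds.pivot_input_data(tds.full_df_raw)
+             >>> sorted(wide.columns)
+             ['conc_armA', 'conc_armB', 'dose', 'id']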
+         """
+
+         # Util function to rename columns as `output_protocol`
+         def join_if_two(tup: tuple[str, str]) -> str:
+             if tup[0] == "":
+                 return tup[1]
+             elif tup[1] == "":
+                 return tup[0]
+             else:
+                 return "_".join(tup)
+
+         # Pivot the data set
+         reshaped_df = data_in.pivot(
+             index=["id"] + self.parameter_names,
+             columns=["output_name", "protocol_arm"],
+             values="value",
+         ).reset_index()
+         reshaped_df.columns = list(map(join_if_two, reshaped_df.columns.to_series()))
+
+         assert set(reshaped_df.columns) == set(
+             ["id"] + self.parameter_names + self.tasks
+         ), "Incomplete training data set provided."
+
+         return reshaped_df
+
+     def normalize_data(
+         self, data_in: pd.DataFrame, ignore: list[str]
+     ) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
+         """Normalize a data frame columnwise to zero mean and unit standard deviation, ignoring the listed columns.
+         selected_columns = data_in.columns.difference(ignore)
+         norm_data = data_in.copy()  # copy so the caller's data frame is not mutated
+         mean = data_in[selected_columns].mean()
+         std = data_in[selected_columns].std()
+         norm_data[selected_columns] = (norm_data[selected_columns] - mean) / std
+         return norm_data, mean, std
+
+     def unnormalize_output_wide_logic(self, data: torch.Tensor) -> torch.Tensor:
+         """Unnormalize wide outputs (all tasks included) from the model."""
+         unnormalized = data * self.normalizing_output_std + self.normalizing_output_mean
+         # Map log-scaled tasks back to the linear scale
+         unnormalized[:, self.log_tasks_indices] = torch.exp(
+             unnormalized[:, self.log_tasks_indices]
+         )
+
+         return unnormalized
+
+     def unnormalize_output_long(
+         self, data: torch.Tensor, task_indices: torch.LongTensor
+     ) -> torch.Tensor:
+         """Unnormalize long outputs (one row per task) from the model."""
+         rescaled_data = data.clone()  # clone so the caller's tensor is not mutated
+         for task_idx in range(self.nb_tasks):
+             log_task = self.tasks[task_idx] in self.log_tasks
+             mask = (task_indices == task_idx).to(device)
+             rescaled_data[mask] = (
+                 rescaled_data[mask] * self.normalizing_output_std[task_idx]
+                 + self.normalizing_output_mean[task_idx]
+             )
+             if log_task:
+                 rescaled_data[mask] = torch.exp(rescaled_data[mask])
+         return rescaled_data
+
+     def normalize_inputs_tensor(self, inputs: torch.Tensor) -> torch.Tensor:
+         """Normalize new inputs provided to the model as a tensor. The columns of the input tensor should match self.parameter_names."""
+         X = inputs.to(device).clone()  # clone so the caller's tensor is not mutated
+         X[:, self.log_inputs_indices] = torch.log(X[:, self.log_inputs_indices])
+         mean = self.normalizing_input_mean
+         std = self.normalizing_input_std
+         norm_X = (X - mean) / std
+
+         return norm_X
+
+     def pivot_outputs_longer(
+         self, comparison_df: pd.DataFrame, Y: torch.Tensor, name: str
+     ) -> pd.DataFrame:
+         """Given wide outputs from a model and a comparison data frame (wide format), add the patient descriptors and reshape to a long format, with a `protocol_arm` and an `output_name` column.
+         # Assuming Y is a wide output from the model, its columns are self.tasks
+         base_df = pd.DataFrame(
+             data=Y.cpu().detach().float().numpy(),
+             columns=self.tasks,
+         )
+         # The rows are assumed to correspond to the rows of the comparison data frame
+         base_df[["id"] + self.parameter_names] = comparison_df[
+             ["id"] + self.parameter_names
+         ]
+         # Pivot the data frame to a long format, separating the task names into protocol arm and output name
+         long_df = (
+             pd.wide_to_long(
+                 df=base_df,
+                 stubnames=self.output_names,
+                 i=["id"] + self.parameter_names,
+                 j="protocol_arm",
+                 sep="_",
+                 suffix=".*",
+             )
+             .reset_index()
+             .melt(
+                 id_vars=["id"] + self.parameter_names + ["protocol_arm"],
+                 value_vars=self.output_names,
+                 var_name="output_name",
+                 value_name=name,
+             )
+         )
+         return long_df
+
+     def get_data_inputs(
+         self, data_set: str | pd.DataFrame
+     ) -> tuple[torch.Tensor, pd.DataFrame, pd.DataFrame, bool]:
+         """Process a new data set of inputs and format them for a surrogate model to use
+
+         The new data may be incomplete. The function expects a long (unpivoted) data table. This function is under-optimized and should not be used during training.
+
+         Args:
+             data_set (str | pd.DataFrame):
+                 Either "training" or "validation", OR
+                 an input data frame on which to predict with the GP. Should contain the following columns
+                 - `id`
+                 - one column per descriptor
+                 - `output_name`
+                 - `protocol_arm` (optional)
+                 - `value` (optional)
+
+         Returns:
+             torch.Tensor: the inputs to provide to a surrogate model for predicting the same values as provided in the data set
+             pd.DataFrame: the processed data frame, in a wide format
+             pd.DataFrame: the original data frame, in a long format
+             bool: a flag, True if the `value` column is a dummy in the output data frames
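+
+         Example (illustrative sketch, reusing the `tds` instance built in the
+         class docstring):
+
+             >>> X, wide_df, long_df, dummy_value = tds.get_data_inputs("training")
+             >>> X.shape[1] == tds.nb_parameters
+             True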
+         """
+         if isinstance(data_set, str):
+             if data_set == "training":
+                 patients = self.training_patients
+             elif data_set == "validation":
+                 patients = self.validation_patients
+             else:
+                 raise ValueError(
+                     f"Incorrect data set choice: {data_set}. Use `training` or `validation`."
+                 )
+             new_data = self.full_df_raw.loc[self.full_df_raw["id"].isin(patients)].copy()
+         elif isinstance(data_set, pd.DataFrame):
+             new_data = data_set.copy()  # copy so the caller's data frame is not mutated
+         else:
+             raise ValueError(
+                 "`get_data_inputs` expects either a str (`training`|`validation`) or a data frame."
+             )
+
+         # Validate the content of the new data frame
+         new_columns = new_data.columns.to_list()
+         if "protocol_arm" not in new_columns:
+             new_data["protocol_arm"] = "identity"
+             new_protocols = ["identity"]
+         else:
+             new_protocols = new_data["protocol_arm"].unique().tolist()
+         new_outputs = new_data["output_name"].unique().tolist()
+         if not (set(new_protocols) <= set(self.protocol_arms)):
+             raise ValueError(
+                 "Supplied data frame contains protocol arms absent from the training data."
+             )
+         if not (set(new_outputs) <= set(self.output_names)):
+             raise ValueError(
+                 "Supplied data frame contains model outputs absent from the training data."
+             )
+         if not (set(self.parameter_names) <= set(new_columns)):
+             raise ValueError(
+                 "Not all model descriptors are supplied in the new data frame."
+             )
+
+         # Flag the case where no observed value was supplied
+         remove_value = False
+         if "value" not in new_columns:
+             remove_value = True
+             # Add a dummy `value` column
+             new_data["value"] = 1.0
+
+         wide_df = self.pivot_input_data(new_data)
+         tensor_inputs_wide = torch.as_tensor(
+             wide_df[self.parameter_names].values, device=device
+         )
+
+         return tensor_inputs_wide, wide_df, new_data, remove_value
378
+
379
+ def merge_predictions_long(
380
+ self,
381
+ pred: tuple[torch.Tensor, torch.Tensor],
382
+ wide_df: pd.DataFrame,
383
+ long_df: pd.DataFrame,
384
+ remove_value: bool,
385
+ ) -> pd.DataFrame:
386
+ """Merge model predictions with an observation data frame
387
+
388
+ Args:
389
+ pred (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Predictions from a model: mean, low bound, high bound. Predictions are expected in a wide format (as many columns as tasks)
390
+ wide_df (pd.DataFrame): The comparison data frame, in a wide format
391
+ long_df (pd.DataFrame): The comparison data frame, in a long format
392
+ remove_value (bool): True if the value column should be ignored
393
+
394
+ Returns:
395
+ pd.DataFrame: A merged data frame in a long format, identical to the initial data, with additional columns [`pred_mean`, `pred_var`, `pred_low`, `pred_high`]
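+
+         Example (illustrative sketch; `surrogate` is a hypothetical model whose
+         `predict` method returns a wide (mean, variance) pair)::
+
+             X, wide_df, long_df, dummy = tds.get_data_inputs("validation")
+             pred = surrogate.predict(X)
+             results = tds.merge_predictions_long(pred, wide_df, long_df, dummy)
+             # results now carries pred_mean, pred_var, pred_low and pred_high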
+         """
+         pred_mean, pred_variance = pred
+         # Reshape these outputs into a long format
+         mean_df = self.pivot_outputs_longer(wide_df, pred_mean, "pred_mean")
+         var_df = self.pivot_outputs_longer(wide_df, pred_variance, "pred_var")
+         # Merge the model results with the long format data frame
+         out_df = reduce(
+             lambda left, right: pd.merge(
+                 left,
+                 right,
+                 on=["id"] + self.parameter_names + ["protocol_arm", "output_name"],
+                 how="left",
+             ),
+             [long_df, mean_df, var_df],
+         )
+         # Mean +/- 2 standard deviations: roughly a 95% interval for a Gaussian prediction
+         out_df["pred_low"] = out_df["pred_mean"] - 2 * np.sqrt(out_df["pred_var"])
+         out_df["pred_high"] = out_df["pred_mean"] + 2 * np.sqrt(out_df["pred_var"])
+         # Remove the dummy value column if it was added during the data processing
+         if remove_value:
+             out_df = out_df.drop(columns=["value"])
+         return out_df
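+
+
+ # Typical workflow (illustrative sketch; `SurrogateModel` is a hypothetical
+ # consumer of this container, not something this module provides):
+ #
+ #     tds = TrainingDataSet(training_df, descriptors=["dose"], log_outputs=["conc"])
+ #     model = SurrogateModel(tds.X_training, tds.Y_training)  # train on normalized data
+ #     X_new = tds.normalize_inputs_tensor(raw_inputs)         # normalize fresh inputs
+ #     Y_new = tds.unnormalize_output_wide(model(X_new))       # back to the original scale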