vpop-calibration 2.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vpop_calibration/__init__.py +22 -0
- vpop_calibration/data_generation.py +186 -0
- vpop_calibration/diagnostics.py +162 -0
- vpop_calibration/model/__init__.py +3 -0
- vpop_calibration/model/data.py +420 -0
- vpop_calibration/model/gp.py +517 -0
- vpop_calibration/model/plot.py +243 -0
- vpop_calibration/nlme.py +840 -0
- vpop_calibration/ode.py +203 -0
- vpop_calibration/saem.py +945 -0
- vpop_calibration/structural_model.py +200 -0
- vpop_calibration/test/__init__.py +11 -0
- vpop_calibration/test/test_data.py +21 -0
- vpop_calibration/test/test_gp_flavors.py +89 -0
- vpop_calibration/test/test_gp_saem.py +175 -0
- vpop_calibration/test/test_ode_saem.py +121 -0
- vpop_calibration/utils.py +9 -0
- vpop_calibration/vpop.py +50 -0
- vpop_calibration-2.2.8.dist-info/METADATA +78 -0
- vpop_calibration-2.2.8.dist-info/RECORD +22 -0
- vpop_calibration-2.2.8.dist-info/WHEEL +4 -0
- vpop_calibration-2.2.8.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from functools import reduce
|
|
6
|
+
|
|
7
|
+
from ..utils import device
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TrainingDataSet:
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
training_df: pd.DataFrame,
|
|
14
|
+
descriptors: list[str],
|
|
15
|
+
training_proportion: float = 0.7,
|
|
16
|
+
log_lower_limit: float = 1e-10,
|
|
17
|
+
log_inputs: list[str] = [],
|
|
18
|
+
log_outputs: list[str] = [],
|
|
19
|
+
data_already_normalized: bool = False,
|
|
20
|
+
):
|
|
21
|
+
"""Instantiate a TrainingDataSet container
|
|
22
|
+
|
|
23
|
+
The data container is used to process training data, normalize inputs and outputs, and provide utilitary methods to transform inputs and outputs for a PyTorch model. In particular, it transforms output names and protocol arms into individual tasks by cartesian product, and it provides mappings between tasks and protcol/output.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
training_df (pd.DataFrame): the training data. Should contain the columns [`id`, `output_name`, `protocol_arm`, *descriptors, `value`]
|
|
27
|
+
descriptors (list[str]): the names of the columns of `training_df` which correspond to descriptors on which to train the model
|
|
28
|
+
training_proportion (float, optional): Proportion of patients to be used as training vs. validation. Defaults to 0.7.
|
|
29
|
+
log_lower_limit(float): epsilon value that is added to all rescaled value to avoid numerical errors when log-scaling variables
|
|
30
|
+
log_inputs (list[str]): the list of parameter inputs which should be rescaled to log when fed to the GP. Avoid adding time here, or any parameter that takes 0 as a value.
|
|
31
|
+
log_outputs (list[str]): list of model outptus which should be rescaled to log
|
|
32
|
+
data_already_normalized(bool): set to True if the data set is preprocessed and no normalization / scaling is required
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
# Process the supplied data set
|
|
36
|
+
self.full_df_raw = training_df
|
|
37
|
+
|
|
38
|
+
declared_columns = self.full_df_raw.columns.to_list()
|
|
39
|
+
# Input validation
|
|
40
|
+
if not ("id" in declared_columns):
|
|
41
|
+
raise ValueError("Training data should contain an `id` column.")
|
|
42
|
+
if not ("output_name" in declared_columns):
|
|
43
|
+
raise ValueError("Training data should contain an `output_name` column.")
|
|
44
|
+
if not ("value" in declared_columns):
|
|
45
|
+
raise ValueError("Training data should contain a `value` column.")
|
|
46
|
+
if not set(descriptors) <= set(declared_columns):
|
|
47
|
+
raise ValueError(
|
|
48
|
+
f"The provided inputs are not declared in the data set: {descriptors}."
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
self.parameter_names = descriptors
|
|
52
|
+
self.training_proportion = training_proportion
|
|
53
|
+
self.nb_parameters = len(self.parameter_names)
|
|
54
|
+
self.data_already_normalized = data_already_normalized
|
|
55
|
+
if not ("protocol_arm" in declared_columns):
|
|
56
|
+
self.full_df_raw["protocol_arm"] = "identity"
|
|
57
|
+
self.protocol_arms = self.full_df_raw["protocol_arm"].unique().tolist()
|
|
58
|
+
self.nb_protocol_arms = len(self.protocol_arms)
|
|
59
|
+
self.output_names = self.full_df_raw["output_name"].unique().tolist()
|
|
60
|
+
self.nb_outputs = len(self.output_names)
|
|
61
|
+
self.log_lower_limit = log_lower_limit
|
|
62
|
+
self.log_inputs = log_inputs
|
|
63
|
+
self.log_inputs_indices = [
|
|
64
|
+
self.parameter_names.index(p) for p in self.log_inputs
|
|
65
|
+
]
|
|
66
|
+
self.log_outputs = log_outputs
|
|
67
|
+
|
|
68
|
+
# Ensure input df has a consistent shape (and remove potential extra columns)
|
|
69
|
+
self.full_df_raw = self.full_df_raw[
|
|
70
|
+
["id"] + self.parameter_names + ["output_name", "protocol_arm", "value"]
|
|
71
|
+
]
|
|
72
|
+
|
|
73
|
+
# Gather the list of patients in the training data
|
|
74
|
+
self.patients = self.full_df_raw["id"].unique()
|
|
75
|
+
self.nb_patients = self.patients.shape[0]
|
|
76
|
+
|
|
77
|
+
print(
|
|
78
|
+
f"Successfully loaded a training data set with {self.nb_patients} patients. The following outputs are available:\n{self.output_names}\n and the following protocol arms:\n{self.protocol_arms}"
|
|
79
|
+
)
|
|
80
|
+
# Construct the list of tasks, mapping from output name and protocol arm to task number
|
|
81
|
+
self.tasks: list[str] = [
|
|
82
|
+
output + "_" + protocol
|
|
83
|
+
for protocol in self.protocol_arms
|
|
84
|
+
for output in self.output_names
|
|
85
|
+
]
|
|
86
|
+
self.nb_tasks = len(self.tasks)
|
|
87
|
+
# Map tasks to output names
|
|
88
|
+
self.task_to_output = {
|
|
89
|
+
output_name + "_" + protocol_arm: output_name
|
|
90
|
+
for output_name in self.output_names
|
|
91
|
+
for protocol_arm in self.protocol_arms
|
|
92
|
+
}
|
|
93
|
+
# Map task index to output index
|
|
94
|
+
self.task_idx_to_output_idx = {
|
|
95
|
+
self.tasks.index(k): self.output_names.index(v)
|
|
96
|
+
for k, v in self.task_to_output.items()
|
|
97
|
+
}
|
|
98
|
+
# Map task to protocol arm
|
|
99
|
+
self.task_to_protocol = {
|
|
100
|
+
output_name + "_" + protocol_arm: protocol_arm
|
|
101
|
+
for output_name in self.output_names
|
|
102
|
+
for protocol_arm in self.protocol_arms
|
|
103
|
+
}
|
|
104
|
+
# Map task index to protocol arm
|
|
105
|
+
self.task_idx_to_protocol = {
|
|
106
|
+
self.tasks.index(k): v for k, v in self.task_to_protocol.items()
|
|
107
|
+
}
|
|
108
|
+
# list tasks that should be rescaled to log
|
|
109
|
+
self.log_tasks = [
|
|
110
|
+
task for task in self.tasks if self.task_to_output[task] in self.log_outputs
|
|
111
|
+
]
|
|
112
|
+
self.log_tasks_indices = [self.tasks.index(task) for task in self.log_tasks]
|
|
113
|
+
|
|
114
|
+
## Data processing
|
|
115
|
+
# Pivot the data to the correct shape for GP training
|
|
116
|
+
self.full_df_reshaped = self.pivot_input_data(self.full_df_raw)
|
|
117
|
+
|
|
118
|
+
# Normalize the inputs and the outputs (only if required)
|
|
119
|
+
if self.data_already_normalized == True:
|
|
120
|
+
self.normalized_df = self.full_df_reshaped
|
|
121
|
+
else:
|
|
122
|
+
self.full_df_reshaped[self.log_inputs + self.log_tasks] = (
|
|
123
|
+
self.full_df_reshaped[self.log_inputs + self.log_tasks].apply(
|
|
124
|
+
lambda val: np.log(np.maximum(val, self.log_lower_limit))
|
|
125
|
+
)
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
self.normalized_df, mean, std = self.normalize_data(
|
|
129
|
+
self.full_df_reshaped, ["id"]
|
|
130
|
+
)
|
|
131
|
+
self.normalizing_input_mean, self.normalizing_input_std = (
|
|
132
|
+
torch.as_tensor(mean.loc[self.parameter_names].values, device=device),
|
|
133
|
+
torch.as_tensor(std.loc[self.parameter_names].values, device=device),
|
|
134
|
+
)
|
|
135
|
+
self.normalizing_output_mean, self.normalizing_output_std = (
|
|
136
|
+
torch.as_tensor(mean.loc[self.tasks].values, device=device),
|
|
137
|
+
torch.as_tensor(std.loc[self.tasks].values, device=device),
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
self.unnormalize_output_wide = torch.compile(self.unnormalize_output_wide_logic)
|
|
141
|
+
|
|
142
|
+
# Compute the number of patients for training
|
|
143
|
+
self.nb_patients_training = math.floor(
|
|
144
|
+
self.training_proportion * self.nb_patients
|
|
145
|
+
)
|
|
146
|
+
self.nb_patients_validation = self.nb_patients - self.nb_patients_training
|
|
147
|
+
|
|
148
|
+
if self.training_proportion != 1: # non-empty validation data set
|
|
149
|
+
if self.nb_patients_training == self.nb_patients:
|
|
150
|
+
raise ValueError(
|
|
151
|
+
"Training proportion too high for the number of sets of parameters: all would be used for training. Set training_proportion as 1 if that is your intention."
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
# Randomly mixing up patients
|
|
155
|
+
mixed_patients = np.random.permutation(self.patients)
|
|
156
|
+
|
|
157
|
+
self.training_patients = mixed_patients[: self.nb_patients_training]
|
|
158
|
+
self.validation_patients = mixed_patients[self.nb_patients_training :]
|
|
159
|
+
|
|
160
|
+
self.training_df_normalized: pd.DataFrame = self.normalized_df.loc[
|
|
161
|
+
self.normalized_df["id"].isin(self.training_patients)
|
|
162
|
+
]
|
|
163
|
+
self.validation_df_normalized: pd.DataFrame = self.normalized_df.loc[
|
|
164
|
+
self.normalized_df["id"].isin(self.validation_patients)
|
|
165
|
+
]
|
|
166
|
+
self.X_validation = torch.as_tensor(
|
|
167
|
+
self.validation_df_normalized[self.parameter_names].values,
|
|
168
|
+
device=device,
|
|
169
|
+
)
|
|
170
|
+
self.Y_validation = torch.as_tensor(
|
|
171
|
+
self.validation_df_normalized[self.tasks].values, device=device
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
else: # no validation data set provided
|
|
175
|
+
self.training_df_normalized = self.normalized_df
|
|
176
|
+
self.validation_df = None
|
|
177
|
+
self.X_validation = None
|
|
178
|
+
self.Y_validation = None
|
|
179
|
+
|
|
180
|
+
self.X_training: torch.Tensor = torch.as_tensor(
|
|
181
|
+
self.training_df_normalized[self.parameter_names].values, device=device
|
|
182
|
+
)
|
|
183
|
+
self.Y_training: torch.Tensor = torch.as_tensor(
|
|
184
|
+
self.training_df_normalized[self.tasks].values, device=device
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
def pivot_input_data(self, data_in: pd.DataFrame) -> pd.DataFrame:
|
|
188
|
+
"""Pivot and reorder columns from a data frame to feed to the model
|
|
189
|
+
|
|
190
|
+
This method is used at initialization on the training data frame), and when plotting the model performance against existing data.
|
|
191
|
+
|
|
192
|
+
Args:
|
|
193
|
+
data_in (pd.DataFrame): Input data frame, containing the following columns
|
|
194
|
+
- `id`: patient id
|
|
195
|
+
- one column per descriptor, the same descriptors as self.parameter_names should be present
|
|
196
|
+
- `output_name`: the name of the output
|
|
197
|
+
- `protocol_arm`: the name of the protocol arm
|
|
198
|
+
- `value`: the observed value
|
|
199
|
+
|
|
200
|
+
Returns:
|
|
201
|
+
pd.DataFrame: A validated and pivotted dataframe with one column per task (`outputName_protocolArm`), and one row per observation
|
|
202
|
+
"""
|
|
203
|
+
|
|
204
|
+
# util function to rename columns as `output_protocol`
|
|
205
|
+
def join_if_two(tup: list[str]) -> str:
|
|
206
|
+
if tup[0] == "":
|
|
207
|
+
return tup[1]
|
|
208
|
+
elif tup[1] == "":
|
|
209
|
+
return tup[0]
|
|
210
|
+
else:
|
|
211
|
+
return "_".join(tup)
|
|
212
|
+
|
|
213
|
+
# Pivot the data set
|
|
214
|
+
reshaped_df = data_in.pivot(
|
|
215
|
+
index=["id"] + self.parameter_names,
|
|
216
|
+
columns=["output_name", "protocol_arm"],
|
|
217
|
+
values="value",
|
|
218
|
+
).reset_index()
|
|
219
|
+
reshaped_df.columns = list(map(join_if_two, reshaped_df.columns.to_series()))
|
|
220
|
+
|
|
221
|
+
assert set(reshaped_df.columns) == set(
|
|
222
|
+
["id"] + self.parameter_names + self.tasks
|
|
223
|
+
), "Incomplete training data set provided."
|
|
224
|
+
|
|
225
|
+
return reshaped_df
|
|
226
|
+
|
|
227
|
+
def normalize_data(
|
|
228
|
+
self, data_in: pd.DataFrame, ignore: list[str]
|
|
229
|
+
) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
|
|
230
|
+
"""Normalize a data frame with respect to its mean and std, ignoring certain columns."""
|
|
231
|
+
selected_columns = data_in.columns.difference(ignore)
|
|
232
|
+
norm_data = data_in
|
|
233
|
+
mean = data_in[selected_columns].mean()
|
|
234
|
+
std = data_in[selected_columns].std()
|
|
235
|
+
norm_data[selected_columns] = (norm_data[selected_columns] - mean) / std
|
|
236
|
+
return norm_data, mean, std
|
|
237
|
+
|
|
238
|
+
def unnormalize_output_wide_logic(self, data: torch.Tensor) -> torch.Tensor:
|
|
239
|
+
"""Unnormalize wide outputs (all tasks included) from the model."""
|
|
240
|
+
unnormalized = data * self.normalizing_output_std + self.normalizing_output_mean
|
|
241
|
+
unnormalized[:, self.log_tasks_indices] = torch.exp(
|
|
242
|
+
unnormalized[:, self.log_tasks_indices]
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
return unnormalized
|
|
246
|
+
|
|
247
|
+
def unnormalize_output_long(
|
|
248
|
+
self, data: torch.Tensor, task_indices: torch.LongTensor
|
|
249
|
+
) -> torch.Tensor:
|
|
250
|
+
"""Unnormalize long outputs (one row per task) from the model."""
|
|
251
|
+
rescaled_data = data
|
|
252
|
+
for task_idx in range(self.nb_tasks):
|
|
253
|
+
log_task = self.tasks[task_idx] in self.log_tasks
|
|
254
|
+
mask = torch.tensor(task_indices == task_idx, device=device).bool()
|
|
255
|
+
rescaled_data[mask] = (
|
|
256
|
+
rescaled_data[mask] * self.normalizing_output_std[task_idx]
|
|
257
|
+
+ self.normalizing_output_mean[task_idx]
|
|
258
|
+
)
|
|
259
|
+
if log_task:
|
|
260
|
+
rescaled_data[mask] = torch.exp(rescaled_data[mask])
|
|
261
|
+
return rescaled_data
|
|
262
|
+
|
|
263
|
+
def normalize_inputs_tensor(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
264
|
+
"""Normalize new inputs provided to the model as a tensor. The columns of the input tensor should be the same as [self.descriptors]"""
|
|
265
|
+
X = inputs.to(device)
|
|
266
|
+
X[:, self.log_inputs_indices] = torch.log(X[:, self.log_inputs_indices])
|
|
267
|
+
mean = self.normalizing_input_mean
|
|
268
|
+
std = self.normalizing_input_std
|
|
269
|
+
norm_X = (X - mean) / std
|
|
270
|
+
|
|
271
|
+
return norm_X
|
|
272
|
+
|
|
273
|
+
def pivot_outputs_longer(
|
|
274
|
+
self, comparison_df: pd.DataFrame, Y: torch.Tensor, name: str
|
|
275
|
+
) -> pd.DataFrame:
|
|
276
|
+
"""Given wide outputs from a model and a comparison data frame (wide format), add the patient descriptors and reshape to a long format, with a `protocol_arm` and an `output_name` column."""
|
|
277
|
+
# Assuming Y is a wide output from the model, its columns are self.tasks
|
|
278
|
+
base_df = pd.DataFrame(
|
|
279
|
+
data=Y.cpu().detach().float().numpy(),
|
|
280
|
+
columns=self.tasks,
|
|
281
|
+
)
|
|
282
|
+
# The rows are assumed to correspond to the rows of the comparison data frame
|
|
283
|
+
base_df[["id"] + self.parameter_names] = comparison_df[
|
|
284
|
+
["id"] + self.parameter_names
|
|
285
|
+
]
|
|
286
|
+
# Pivot the data frame to a long format, separating the task names into protocol arm and output name
|
|
287
|
+
long_df = (
|
|
288
|
+
pd.wide_to_long(
|
|
289
|
+
df=base_df,
|
|
290
|
+
stubnames=self.output_names,
|
|
291
|
+
i=["id"] + self.parameter_names,
|
|
292
|
+
j="protocol_arm",
|
|
293
|
+
sep="_",
|
|
294
|
+
suffix=".*",
|
|
295
|
+
)
|
|
296
|
+
.reset_index()
|
|
297
|
+
.melt(
|
|
298
|
+
id_vars=["id"] + self.parameter_names + ["protocol_arm"],
|
|
299
|
+
value_vars=self.output_names,
|
|
300
|
+
var_name="output_name",
|
|
301
|
+
value_name=name,
|
|
302
|
+
)
|
|
303
|
+
)
|
|
304
|
+
return long_df
|
|
305
|
+
|
|
306
|
+
def get_data_inputs(
|
|
307
|
+
self, data_set: str | pd.DataFrame
|
|
308
|
+
) -> tuple[torch.Tensor, pd.DataFrame, pd.DataFrame, bool]:
|
|
309
|
+
"""Process a new data set of inputs and format them for a surrogate model to use
|
|
310
|
+
|
|
311
|
+
The new data may be incomplete. The function expects a long data table (unpivotted). This function is under-optimized, and should not be used during training.
|
|
312
|
+
|
|
313
|
+
Args:
|
|
314
|
+
data_set (str | pd.DataFrame):
|
|
315
|
+
Either "training" or "validation" OR
|
|
316
|
+
An input data frame on which to predict with the GP. Should contain the following columns
|
|
317
|
+
- `id`
|
|
318
|
+
- one column per descriptor
|
|
319
|
+
- `protocol_name`
|
|
320
|
+
- `value` (Optional)
|
|
321
|
+
|
|
322
|
+
Returns:
|
|
323
|
+
torch.Tensor: the inputs to provide to a surrogate model for predicting the same values as provided in the data set
|
|
324
|
+
pd.DataFrame: the processed data frame, in a wide format
|
|
325
|
+
pd.DataFrame: the original data frame, in a long format
|
|
326
|
+
bool: a flag, True if the value column is dummy in the output data frames
|
|
327
|
+
"""
|
|
328
|
+
if isinstance(data_set, str):
|
|
329
|
+
if data_set == "training":
|
|
330
|
+
patients = self.training_patients
|
|
331
|
+
elif data_set == "validation":
|
|
332
|
+
patients = self.validation_patients
|
|
333
|
+
else:
|
|
334
|
+
raise ValueError(
|
|
335
|
+
f"Incorrect data set choice: {data_set}. Use `training` or `validation`"
|
|
336
|
+
)
|
|
337
|
+
new_data = self.full_df_raw.loc[self.full_df_raw["id"].isin(patients)]
|
|
338
|
+
elif isinstance(data_set, pd.DataFrame):
|
|
339
|
+
new_data = data_set
|
|
340
|
+
else:
|
|
341
|
+
raise ValueError(
|
|
342
|
+
"`predict_new_data` expects either a str (`training`|`validation`) or a data frame."
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
# Validate the content of the new data frame
|
|
346
|
+
new_columns = new_data.columns.to_list()
|
|
347
|
+
if not "protocol_arm" in new_columns:
|
|
348
|
+
new_protocols = ["identity"]
|
|
349
|
+
else:
|
|
350
|
+
new_protocols = new_data["protocol_arm"].unique().tolist()
|
|
351
|
+
new_outputs = new_data["output_name"].unique().tolist()
|
|
352
|
+
if not (set(new_protocols) <= set(self.protocol_arms)):
|
|
353
|
+
raise ValueError(
|
|
354
|
+
"Supplied data frame contains a different set of protocol arms."
|
|
355
|
+
)
|
|
356
|
+
if not (set(new_outputs) <= set(self.output_names)):
|
|
357
|
+
raise ValueError(
|
|
358
|
+
"Supplied data frame contains a different set of model outputs."
|
|
359
|
+
)
|
|
360
|
+
if not (set(self.parameter_names) <= set(new_columns)):
|
|
361
|
+
raise ValueError(
|
|
362
|
+
"All model descriptors are not supplied in the new data frame."
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
# Flag the case where no observed value was supplied
|
|
366
|
+
remove_value = False
|
|
367
|
+
if not "value" in new_columns:
|
|
368
|
+
remove_value = True
|
|
369
|
+
# Add a dummy `value` column
|
|
370
|
+
new_data["value"] = 1.0
|
|
371
|
+
|
|
372
|
+
wide_df = self.pivot_input_data(new_data)
|
|
373
|
+
tensor_inputs_wide = torch.as_tensor(
|
|
374
|
+
wide_df[self.parameter_names].values, device=device
|
|
375
|
+
)
|
|
376
|
+
|
|
377
|
+
return tensor_inputs_wide, wide_df, new_data, remove_value
|
|
378
|
+
|
|
379
|
+
def merge_predictions_long(
|
|
380
|
+
self,
|
|
381
|
+
pred: tuple[torch.Tensor, torch.Tensor],
|
|
382
|
+
wide_df: pd.DataFrame,
|
|
383
|
+
long_df: pd.DataFrame,
|
|
384
|
+
remove_value: bool,
|
|
385
|
+
) -> pd.DataFrame:
|
|
386
|
+
"""Merge model predictions with an observation data frame
|
|
387
|
+
|
|
388
|
+
Args:
|
|
389
|
+
pred (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Predictions from a model: mean, low bound, high bound. Predictions are expected in a wide format (as many columns as tasks)
|
|
390
|
+
wide_df (pd.DataFrame): The comparison data frame, in a wide format
|
|
391
|
+
long_df (pd.DataFrame): The comparison data frame, in a long format
|
|
392
|
+
remove_value (bool): True if the value column should be ignored
|
|
393
|
+
|
|
394
|
+
Returns:
|
|
395
|
+
pd.DataFrame: A merged data frame in a long format, identical to the initial data, with additional columns [`pred_mean`, `pred_var`, `pred_low`, `pred_high`]
|
|
396
|
+
"""
|
|
397
|
+
pred_mean, pred_variance = pred
|
|
398
|
+
# Reshape these outputs into a long format
|
|
399
|
+
mean_df = self.pivot_outputs_longer(wide_df, pred_mean, "pred_mean")
|
|
400
|
+
var_df = self.pivot_outputs_longer(wide_df, pred_variance, "pred_var")
|
|
401
|
+
# Merge the model results with the long format data frame
|
|
402
|
+
out_df = reduce(
|
|
403
|
+
lambda left, right: pd.merge(
|
|
404
|
+
left,
|
|
405
|
+
right,
|
|
406
|
+
on=["id"] + self.parameter_names + ["protocol_arm", "output_name"],
|
|
407
|
+
how="left",
|
|
408
|
+
),
|
|
409
|
+
[long_df, mean_df, var_df],
|
|
410
|
+
)
|
|
411
|
+
out_df["pred_low"] = out_df.apply(
|
|
412
|
+
lambda r: r["pred_mean"] - 2 * np.sqrt(r["pred_var"]), axis=1
|
|
413
|
+
)
|
|
414
|
+
out_df["pred_high"] = out_df.apply(
|
|
415
|
+
lambda r: r["pred_mean"] + 2 * np.sqrt(r["pred_var"]), axis=1
|
|
416
|
+
)
|
|
417
|
+
# Remove the dummy value column if it was added during the data processing
|
|
418
|
+
if remove_value:
|
|
419
|
+
out_df = out_df.drop(columns=["value"])
|
|
420
|
+
return out_df
|