vpop-calibration 2.2.8 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,517 @@
1
+ from matplotlib import pyplot as plt
2
+ import math
3
+ import torch
4
+ import gpytorch
5
+ from tqdm.notebook import tqdm
6
+ from gpytorch.mlls import VariationalELBO, PredictiveLogLikelihood
7
+ from torch.utils.data import TensorDataset, DataLoader
8
+ import numpy as np
9
+ import pandas as pd
10
+ from typing import Optional, cast
11
+
12
+ from .data import TrainingDataSet
13
+ from .plot import (
14
+ plot_all_solutions,
15
+ plot_individual_solution,
16
+ plot_obs_vs_predicted,
17
+ plot_loss,
18
+ )
19
+ from ..utils import smoke_test, device
20
+
21
+ torch.set_default_dtype(torch.float64)
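+ # note: gpytorch.settings.cholesky_jitter is a context manager; for the value below to
+ # take effect, wrap GP calls in `with gpytorch.settings.cholesky_jitter(1e-6): ...`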
22
+ gpytorch.settings.cholesky_jitter(1e-6)
23
+
24
+
25
+ class SVGP(gpytorch.models.ApproximateGP):
26
+ """The internal GP class used to create surrogate models, interfacing with gpytorch's API"""
27
+
28
+ def __init__(
29
+ self,
30
+ inducing_points: torch.Tensor,
31
+ nb_params: int,
32
+ nb_tasks: int,
33
+ nb_latents: int,
34
+ var_dist: str = "Chol",
35
+ var_strat: str = "IMV",
36
+ kernel: str = "RBF",
37
+ deep_kernel: bool = False,
38
+ jitter: float = 1e-6,
39
+ nb_mixtures: int = 4, # only for the SMK kernel
40
+ nb_features: int = 10,
41
+ ):
42
+ """_summary_
43
+
44
+ Args:
45
+ inducing_points (torch.Tensor): Initial choice for the inducing points
46
+ nb_params (int): Number of input parameters
47
+ nb_tasks (int): Number of outputs (tasks)
+ nb_latents (int): Number of latent functions (only used with the "LMCV" strategy)
48
+ var_dist (str, optional): Variational distribution choice. Defaults to "Chol".
49
+ var_strat (str, optional): Variational strategy choice. Defaults to "IMV".
50
+ kernel (str, optional): Kernel choice. Defaults to "RBF".
51
+ deep_kernel (bool, optional): Add a neural network feature extractor in front of the kernel. Defaults to False.
52
+ jitter (float, optional): Jitter value (for numerical stability). Defaults to 1e-6.
53
+ nb_mixtures (int, optional): Number of mixtures for the SMK kernel. Defaults to 4.
54
+ nb_features (int, optional): Number of features for the deep kernel. Defaults to 10.
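+
+ Example (an illustrative sketch, not a packaged test; the shapes and sizes below are assumptions):
+
+ Z = torch.randn(20, 3)  # 20 initial inducing points in a 3-parameter input space
+ svgp = SVGP(Z, nb_params=3, nb_tasks=2, nb_latents=2)  # IMV strategy and RBF kernel by default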
55
+ """
56
+ assert var_dist == "Chol", f"Unsupported variational distribution: {var_dist}"
57
+ if var_strat == "LMCV":
58
+ self.batch_size = nb_latents
59
+ elif var_strat == "IMV":
60
+ self.batch_size = nb_tasks
61
+ else:
62
+ self.batch_size = nb_tasks
63
+
64
+ self.kernel_type = kernel
65
+ self.deep_kernel = deep_kernel
66
+ if deep_kernel:
67
+ self.kernel_size = nb_features
68
+ else:
69
+ self.kernel_size = nb_params
70
+
71
+ variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
72
+ inducing_points.shape[0],
73
+ batch_shape=torch.Size([self.batch_size]),
74
+ mean_init_std=1e-3,
75
+ )
76
+ if var_strat == "IMV":
77
+ variational_strategy = (
78
+ gpytorch.variational.IndependentMultitaskVariationalStrategy(
79
+ gpytorch.variational.VariationalStrategy(
80
+ self,
81
+ inducing_points,
82
+ variational_distribution,
83
+ learn_inducing_locations=True,
84
+ jitter_val=jitter,
85
+ ),
86
+ num_tasks=nb_tasks,
87
+ )
88
+ )
89
+ elif var_strat == "LMCV":
90
+ variational_strategy = gpytorch.variational.LMCVariationalStrategy(
91
+ gpytorch.variational.VariationalStrategy(
92
+ self,
93
+ inducing_points,
94
+ variational_distribution,
95
+ learn_inducing_locations=True,
96
+ jitter_val=jitter,
97
+ ),
98
+ num_tasks=nb_tasks,
99
+ num_latents=nb_latents,
100
+ latent_dim=-1,
101
+ )
102
+ else:
103
+ raise ValueError(f"Unsupported variational strategy {var_strat}")
104
+
105
+ super().__init__(variational_strategy)
106
+
107
+ # TODO: allow for different mean choices
108
+ self.mean_module = gpytorch.means.ConstantMean(
109
+ batch_shape=torch.Size([self.batch_size])
110
+ )
111
+
112
+ if kernel == "RBF":
113
+ self.covar_module = gpytorch.kernels.ScaleKernel(
114
+ gpytorch.kernels.RBFKernel(
115
+ batch_shape=torch.Size([self.batch_size]),
116
+ ard_num_dims=self.kernel_size,
117
+ jitter=jitter,
118
+ ),
119
+ batch_shape=torch.Size([self.batch_size]),
120
+ )
121
+ elif kernel == "SMK":
122
+ self.covar_module = gpytorch.kernels.SpectralMixtureKernel(
123
+ batch_shape=torch.Size([self.batch_size]),
124
+ num_mixtures=nb_mixtures,
125
+ ard_num_dims=self.kernel_size,
126
+ jitter=jitter,
127
+ )
128
+ elif kernel == "Matern":
129
+ self.covar_module = gpytorch.kernels.ScaleKernel(
130
+ gpytorch.kernels.MaternKernel(
131
+ nu=2.5,
132
+ batch_shape=torch.Size([self.batch_size]),
134
+ ard_num_dims=self.kernel_size,
135
+ jitter=jitter,
136
+ ),
137
+ batch_shape=torch.Size([self.batch_size]),
138
+ )
139
+ else:
140
+ raise ValueError(f"Unsupported kernel {kernel}")
141
+ if deep_kernel:
142
+ self.feature_extractor = LargeFeatureExtractor(nb_params, nb_features)
143
+ self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1.0, 1.0)
144
+
145
+ def forward(self, x: torch.Tensor):
146
+ if self.deep_kernel:
147
+ proj_x = self.feature_extractor(x)
148
+ proj_x = self.scale_to_bounds(proj_x)
149
+ mean_x = cast(torch.Tensor, self.mean_module(proj_x))
150
+ covar_x = self.covar_module(proj_x)
151
+ else:
152
+ mean_x = cast(torch.Tensor, self.mean_module(x))
153
+ covar_x = self.covar_module(x)
154
+ return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
155
+
156
+
157
+ class GP:
158
+ """GP surrogate model"""
159
+
160
+ def __init__(
161
+ self,
162
+ training_df: pd.DataFrame,
163
+ descriptors: list[str],
164
+ var_dist: str = "Chol", # only Cholesky currently supported
165
+ var_strat: str = "IMV", # either IMV (Independent Multitask Variational) or LMCV (Linear Model of Coregionalization Variational)
166
+ kernel: str = "RBF", # RBF or SMK
167
+ deep_kernel: bool = True,
168
+ nb_training_iter: int = 400,
169
+ training_proportion: float = 0.7,
170
+ nb_inducing_points: int = 200,
171
+ data_already_normalized: bool = False,
172
+ log_lower_limit: float = 1e-10,
173
+ log_inputs: list[str] = [],
174
+ log_outputs: list[str] = [],
175
+ nb_latents: Optional[int] = None,
176
+ # by default we will use nb_latents = nb_outputs
177
+ mll: str = "ELBO", # ELBO or PLL
178
+ learning_rate: Optional[float] = None, # optional
179
+ lr_decay: Optional[float] = None,
180
+ num_mixtures: int = 4, # only for the SMK kernel
181
+ nb_features: int = 10, # only for the DK kernel
182
+ jitter: float = 1e-6,
183
+ ):
184
+ """Instantiate a GP model on a training data frame
185
+
186
+ Args:
187
+ training_df (pd.DataFrame): Training data frame containing the following columns:
188
+ - `id`: the id of the patient, str or int
189
+ - `descriptors`: one column per patient descriptor (including `time`, if necessary)
190
+ - `output_name`: the name of the simulated model output
191
+ - `value`: the simulated value (for a given patient, protocol arm and output name)
192
+ - `protocol_arm` (optional): the protocol arm on which this patient was simulated. If not provided, `identity` will be used
193
+ descriptors (list[str]): the names of the columns of `training_df` which correspond to descriptors on which to train the GP
194
+ var_dist (str, optional): Variational distribution choice. Defaults to "Chol".
195
+ nb_training_iter (int, optional): Number of iterations for training. Defaults to 400.
196
+ training_proportion (float, optional): Proportion of patients to be used as training vs. validation. Defaults to 0.7.
197
+ nb_inducing_points (int, optional): Number of inducing points to be used for variational inference. Defaults to 200.
198
+ log_inputs (list[str]): the list of parameter inputs which should be rescaled to log when fed to the GP. Avoid adding time here, or any parameter that takes 0 as a value.
199
+ log_outputs (list[str]): list of model outputs which should be rescaled to log
200
+ log_lower_limit (float): epsilon value added to all rescaled values to avoid numerical errors when log-scaling variables
201
+ nb_latents (Optional[int], optional): Number of latents. Defaults to None, implying that nb_latents = nb_tasks will be used
202
+ mll (str, optional): Marginal log likelihood choice. Defaults to "ELBO" (other option "PLL")
203
+ learning_rate (Optional[float]): learning rate initial value. Defaults to 0.001 (in torch.optim.Adam)
204
+ lr_decay (Optional[float]): learning rate decay rate (gamma of torch.optim.lr_scheduler.ExponentialLR). If None, no scheduler is used.
205
+ num_mixtures (int): Number of mixtures used in the SMK kernel. Not used for other kernel choices. Defaults to 4.
206
+ nb_features (int): Number of output features of the deep-kernel feature extractor. Only used when deep_kernel is True. Defaults to 10.
207
+ jitter (float, optional): Jitter value for numerical stability. Defaults to 1e-6.
208
+
209
+ Comments:
210
+ The GP will learn nb_tasks = nb_outputs * nb_protocol_arms, i.e. one predicted task per model output per protocol arm.
211
+
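+ Example:
+ An illustrative sketch (the `time`/`dose` descriptor names and the values below are
+ assumptions, not package defaults):
+
+ df = pd.DataFrame({
+ "id": [1, 1, 2, 2, 3, 3, 4, 4],
+ "time": [0.0, 1.0] * 4,
+ "dose": [5.0, 5.0, 10.0, 10.0, 2.0, 2.0, 8.0, 8.0],
+ "output_name": ["conc"] * 8,
+ "value": [0.1, 0.4, 0.2, 0.9, 0.05, 0.2, 0.15, 0.7],
+ })
+ gp = GP(df, descriptors=["time", "dose"], nb_training_iter=50, deep_kernel=False)
+ gp.train()
+ gp.eval_perf()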
212
+ """
213
+ # Define GP parameters
214
+ self.var_dist = var_dist
215
+ self.var_strat = var_strat
216
+ self.kernel = kernel
217
+ self.deep_kernel = deep_kernel
218
+ if smoke_test:
219
+ self.nb_training_iter = 1
220
+ self.nb_inducing_points = 10
221
+ else:
222
+ self.nb_training_iter = nb_training_iter
223
+ self.nb_inducing_points = nb_inducing_points
224
+ self.learning_rate = learning_rate
225
+ self.mll_name = mll
226
+ self.num_mixtures = num_mixtures
227
+ self.nb_features = nb_features
228
+ self.jitter = jitter
229
+ if lr_decay is not None:
230
+ self.lr_decay = lr_decay
231
+
232
+ self.data = TrainingDataSet(
233
+ training_df,
234
+ descriptors,
235
+ training_proportion,
236
+ log_lower_limit,
237
+ log_inputs,
238
+ log_outputs,
239
+ data_already_normalized,
240
+ )
241
+
242
+ if nb_latents is None:
243
+ self.nb_latents = self.data.nb_tasks
244
+ else:
245
+ self.nb_latents = nb_latents
246
+ # Create inducing points
247
+ self.inducing_points = self.data.X_training[
248
+ torch.randperm(self.data.X_training.shape[0])[: self.nb_inducing_points],
249
+ :,
250
+ ]
251
+
252
+ # Initialize likelihood and model
253
+ self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
254
+ num_tasks=self.data.nb_tasks, has_global_noise=True, has_task_noise=True
255
+ )
256
+ self.model = SVGP(
257
+ inducing_points=self.inducing_points,
258
+ nb_params=self.data.nb_parameters,
259
+ nb_tasks=self.data.nb_tasks,
260
+ nb_latents=self.nb_latents,
261
+ var_dist=self.var_dist,
262
+ var_strat=self.var_strat,
263
+ kernel=self.kernel,
+ deep_kernel=self.deep_kernel,
264
+ jitter=self.jitter,
265
+ nb_mixtures=self.num_mixtures,
266
+ nb_features=self.nb_features,
267
+ )
268
+
269
+ # set the marginal log likelihood
270
+ if self.mll_name == "ELBO":
271
+ self.mll = VariationalELBO(
272
+ self.likelihood, self.model, num_data=self.data.Y_training.size(0)
273
+ )
274
+ elif self.mll_name == "PLL":
275
+ self.mll = PredictiveLogLikelihood(
276
+ self.likelihood, self.model, num_data=self.data.Y_training.size(0)
277
+ )
278
+ else:
279
+ raise ValueError(f"Invalid MLL choice ({self.mll}). Choose ELBO or PLL.")
280
+
281
+ # Move all components to the selected device
282
+ self.model.to(device)
283
+ self.likelihood.to(device)
284
+ self.mll.to(device)
285
+
286
+ def train(
287
+ self, mini_batching: bool = False, mini_batch_size: Optional[int] = None
288
+ ) -> None:
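+ """Train the GP surrogate by maximizing the selected marginal log likelihood.
+
+ Args:
+ mini_batching (bool, optional): Train on mini-batches instead of the full training set. Defaults to False.
+ mini_batch_size (Optional[int], optional): Mini-batch size. Defaults to None, in which case a power of two derived from the training set size is used.
+ """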
289
+ # set model and likelihood in training mode
290
+ self.model.train()
291
+ self.likelihood.train()
292
+
293
+ # initialize the adam optimizer
294
+ params_to_optim = [
295
+ {"params": self.model.parameters()},
296
+ {"params": self.likelihood.parameters()},
297
+ ]
298
+ if self.learning_rate is None:
299
+ optimizer = torch.optim.Adam(params_to_optim)
300
+ else:
301
+ optimizer = torch.optim.Adam(params_to_optim, lr=self.learning_rate)
302
+ if hasattr(self, "lr_decay"):
303
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
304
+ optimizer, gamma=self.lr_decay
305
+ )
306
+ else:
307
+ scheduler = None
308
+
309
+ # keep track of the loss
310
+ losses_list = []
311
+ epochs = tqdm(range(self.nb_training_iter), desc="Epochs", position=0)
312
+ with gpytorch.settings.observation_nan_policy("fill"):
313
+ # Batch training loop
314
+ if mini_batching:
315
+ # default the mini-batch size to a power of two, roughly 1/16 of the training set size (minimum 2)
317
+ if mini_batch_size is None:
317
+ power = np.maximum(
318
+ math.floor(math.log2(self.data.X_training.shape[0])) - 4, 1
319
+ )
320
+ self.mini_batch_size: int = math.floor(2**power)
321
+ else:
322
+ self.mini_batch_size = mini_batch_size
323
+
324
+ # prepare mini-batching
325
+ train_dataset = TensorDataset(
326
+ self.data.X_training, self.data.Y_training
327
+ )
328
+ train_loader = DataLoader(
329
+ train_dataset,
330
+ batch_size=self.mini_batch_size,
331
+ shuffle=True,
332
+ )
333
+
334
+ # main training loop
335
+ for _ in epochs:
336
+ for batch_params, batch_outputs in tqdm(
337
+ train_loader, desc="Batch progress", position=1, leave=False
338
+ ):
339
+ optimizer.zero_grad() # zero gradients from previous iteration
340
+ output = self.model(batch_params) # recalculate the prediction
341
+ loss = -cast(torch.Tensor, self.mll(output, batch_outputs))
342
+ loss.backward() # compute the gradients of the parameters that can be changed
343
+ losses_list.append(loss.item())
+ epochs.set_postfix({"loss": loss.item()})
344
+ optimizer.step()
345
+ if scheduler is not None:
346
+ scheduler.step()
347
+
348
+ # Full data set training loop
349
+ else:
350
+ for _ in epochs:
351
+ optimizer.zero_grad() # zero gradients from previous iteration
352
+ output = self.model(
353
+ self.data.X_training
354
+ ) # calculate the prediction with current parameters
355
+ loss = -cast(torch.Tensor, self.mll(output, self.data.Y_training))
356
+ loss.backward() # compute the gradients of the parameters that can be changed
357
+ losses_list.append(loss.item())
358
+ optimizer.step()
359
+ epochs.set_postfix({"loss": loss.item()})
360
+ if scheduler is not None:
361
+ scheduler.step()
362
+ self.losses = losses_list
363
+
364
+ def _predict_training(self, X: torch.Tensor):
365
+ """Internal method used to predict normalized outputs on normalized inputs."""
366
+ self.model.eval()
367
+ self.likelihood.eval()
368
+
369
+ with torch.no_grad():
370
+ pred = cast(
371
+ gpytorch.distributions.MultitaskMultivariateNormal,
372
+ self.likelihood(self.model(X)),
373
+ )
374
+
375
+ return pred.mean, pred.confidence_region()
376
+
377
+ def predict_wide_scaled(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
378
+ """Predict mean and interval confidence values for a given input tensor (normalized inputs). This function outputs rescaled values."""
379
+ self.model.eval()
380
+ self.likelihood.eval()
381
+ inputs = self.data.normalize_inputs_tensor(X)
382
+
383
+ with torch.no_grad():
384
+ pred = self.model(inputs)
385
+
386
+ if self.data.data_already_normalized:
387
+ out_mean = pred.mean
388
+ else:
389
+ out_mean = self.data.unnormalize_output_wide(pred.mean)
390
+ return out_mean, pred.variance
391
+
392
+ def predict_long_scaled(
393
+ self, X: torch.Tensor, tasks: torch.LongTensor
394
+ ) -> tuple[torch.Tensor, torch.Tensor]:
395
+ """Predict outputs from the GP in a long format (one row per task)"""
396
+
397
+ self.model.eval()
398
+ self.likelihood.eval()
399
+ inputs = self.data.normalize_inputs_tensor(X)
400
+ with torch.no_grad():
401
+ pred = self.model(inputs, task_indices=tasks)
402
+ if self.data.data_already_normalized:
403
+ out_mean = pred.mean
404
+ else:
405
+ out_mean = self.data.unnormalize_output_long(pred.mean, task_indices=tasks)
406
+ return out_mean, pred.variance
407
+
408
+ def plot_loss(self) -> None:
409
+ if not hasattr(self, "losses"):
410
+ raise ValueError("Cannot plot loss before training the model.")
411
+ # plot the loss over iterations
412
+ iterations = np.arange(1, len(self.losses) + 1)
413
+ losses = np.array(self.losses)
414
+ plot_loss(iterations, losses)
415
+
416
+ def RMSE(self, y1: torch.Tensor, y2: torch.Tensor) -> torch.Tensor:
417
+ """Given two tensors of same shape, compute the Root Mean Squared Error on each column (outputs)."""
418
+ assert y1.shape == y2.shape
419
+ # Ignore potential NaN values in the RMSE computation
420
+ mask = (~torch.isnan(y1)) * (~torch.isnan(y2))
421
+ squared_residuals = torch.where(mask, torch.pow(y1 - y2, 2), 0)
422
+ return torch.sqrt(squared_residuals.sum(dim=0) / mask.sum(dim=0))
423
+
424
+ def eval_perf(self):
425
+ """Evaluate the model performance on its training data set and validation data set (normalized inputs and ouptuts)."""
426
+
427
+ def print_task_rmse(index, val):
428
+ print(
429
+ f" Output: {self.data.output_names[self.data.task_idx_to_output_idx[index]]}, protocol: {self.data.task_idx_to_protocol[index]}, RMSE: {val:.4f}"
430
+ )
431
+
432
+ (
433
+ self.Y_training_predicted_mean,
434
+ _,
435
+ ) = self._predict_training(self.data.X_training)
436
+ self.RMSE_training = self.RMSE(
437
+ self.Y_training_predicted_mean, self.data.Y_training
438
+ )
439
+ print("Training data set:")
440
+
441
+ for i, err in enumerate(self.RMSE_training):
442
+ print_task_rmse(i, err.item())
443
+
444
+ if (self.data.X_validation is not None) and (
445
+ self.data.Y_validation is not None
446
+ ):
447
+ (
448
+ self.Y_validation_predicted_mean,
449
+ _,
450
+ ) = self._predict_training(self.data.X_validation)
451
+ self.RMSE_validation = self.RMSE(
452
+ self.Y_validation_predicted_mean, self.data.Y_validation
453
+ )
454
+ print("Validation data set:")
455
+ for i, err in enumerate(self.RMSE_validation):
456
+ print_task_rmse(i, err.item())
457
+
458
+ def predict_new_data(self, data_set: str | pd.DataFrame) -> pd.DataFrame:
459
+ """Process a new data set of inputs and predict using the GP
460
+
461
+ The new data may be incomplete. The function expects a long (unpivoted) data table. This function is not optimized and should not be called inside the training loop.
462
+
463
+ Args:
464
+ data_set (str | pd.DataFrame):
465
+ Either "training" or "validation" OR
466
+ An input data frame on which to predict with the GP. Should contain the following columns
467
+ - `id`
468
+ - one column per descriptor
469
+ - `protocol_name`
470
+
471
+ Returns:
472
+ pd.DataFrame: Same data frame as the input, with additional columns
473
+ - `pred_mean`
474
+ - `pred_low`
475
+ - `pred_high`
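+
+ Example (illustrative; assumes `gp` is an already-trained GP instance):
+
+ val_pred = gp.predict_new_data("validation")
+ val_pred[["pred_mean", "pred_low", "pred_high"]].head()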
476
+ """
477
+ # Fetch the data
478
+ X_wide, wide_df, long_df, remove_value = self.data.get_data_inputs(data_set)
479
+
480
+ # Simulate the model - using a wide output (all tasks predicted for each observation)
481
+ pred = self.predict_wide_scaled(X_wide)
482
+ out_df = self.data.merge_predictions_long(pred, wide_df, long_df, remove_value)
483
+ return out_df
484
+
485
+ def plot_obs_vs_predicted(self, data_set: pd.DataFrame | str, logScale=None):
486
+ """Plots the observed vs. predicted values on the training or validation data set, or on a new data set."""
487
+
488
+ obs_vs_pred = self.predict_new_data(data_set)
489
+ plot_obs_vs_predicted(obs_vs_pred, logScale)
490
+
491
+ # plot function
492
+ def plot_individual_solution(self, patient_number: int):
493
+ """Plot the model prediction (and confidence interval) vs. the input data for a single patient from the GP's internal data set. Can be either training or validation patient."""
494
+ patient_index = self.data.patients[patient_number]
495
+ input_df = self.data.full_df_raw.loc[
496
+ self.data.full_df_raw["id"] == patient_index
497
+ ]
498
+ obs_vs_pred = self.predict_new_data(input_df)
499
+ plot_individual_solution(obs_vs_pred)
500
+
501
+ def plot_all_solutions(self, data_set: str | pd.DataFrame):
502
+ """Plot the overlapped observations and model predictions for all patients, either on one the GP's intrinsic data sets, or on a new data set."""
503
+
504
+ obs_vs_pred = self.predict_new_data(data_set)
505
+ plot_all_solutions(obs_vs_pred)
506
+
507
+
508
+ class LargeFeatureExtractor(torch.nn.Sequential):
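+ """Fully connected feature extractor used by the optional deep kernel; maps n_params inputs to n_features output features."""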
509
+ def __init__(self, n_params, n_features):
510
+ super().__init__()
511
+ self.add_module("linear1", torch.nn.Linear(n_params, 1000))
512
+ self.add_module("relu1", torch.nn.ReLU())
513
+ self.add_module("linear2", torch.nn.Linear(1000, 500))
514
+ self.add_module("relu2", torch.nn.ReLU())
515
+ self.add_module("linear3", torch.nn.Linear(500, 50))
516
+ self.add_module("relu3", torch.nn.ReLU())
517
+ self.add_module("linear4", torch.nn.Linear(50, n_features))