vpop-calibration 2.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,840 @@
+ import torch
+ from typing import Union, Optional
+ import pandas as pd
+ import numpy as np
+
+ from .structural_model import StructuralModel
+ from .utils import device
+
+
+ class NlmeModel:
+     def __init__(
+         self,
+         structural_model: StructuralModel,
+         patients_df: pd.DataFrame,
+         init_log_MI: dict[str, float],
+         init_PDU: dict[str, dict[str, float]],
+         init_res_var: list[float],
+         covariate_map: Optional[dict[str, dict[str, dict[str, str | float]]]] = None,
+         error_model_type: str = "additive",
+         pred_var_threshold: float = 1e-2,
+     ):
+         """Create a non-linear mixed effects model
+
+         Using a structural model (simulation model) and a covariate structure, create a non-linear mixed effects model to be used in PySAEM or another optimizer, or to predict data using a covariance structure.
+
+         Args:
+             structural_model (StructuralModel): A simulation model defined via the convenience class StructuralModel
+             patients_df (DataFrame): the patients to consider. The `id` column is required; a `protocol_arm` column naming the protocol arm each patient was evaluated on is optional (if not supplied, `identity` is used); any additional column is handled as a covariate
+             init_log_MI: for each model intrinsic parameter, an initial value (log scale)
+             init_PDU: for each patient descriptor unknown (PDU) parameter, an initial mean and sd of the log
+             init_res_var: for each model output, an initial residual variance
+             covariate_map (Optional[dict]): for each PDU, the covariates that affect it, each associated with a covariation coefficient (to be calibrated).
+                 Example:
+                     {"pdu_name":
+                         {"covariate_name":
+                             {"coef": "coef_name", "value": initial_value}
+                         }
+                     }
+             error_model_type (str): either `additive` or `proportional` error model
+             pred_var_threshold (float): predictive-variance threshold above which a warning is issued. Defaults to 1e-2.
+         """
+         self.structural_model: StructuralModel = structural_model
+         self.pred_var_threshold = pred_var_threshold
+
+         self.MI_names: list[str] = list(init_log_MI.keys())
+         self.nb_MI: int = len(self.MI_names)
+         self.initial_log_MI = torch.tensor(list(init_log_MI.values())).to(device)
+         self.PDU_names: list[str] = list(init_PDU.keys())
+         self.nb_PDU: int = len(self.PDU_names)
+
+         if set(self.MI_names) & set(self.PDU_names):
+             raise ValueError(
+                 f"Overlapping model intrinsic and PDU descriptors: {set(self.MI_names) & set(self.PDU_names)}"
+             )
+
+         self.patients_df: pd.DataFrame = patients_df.drop_duplicates()
+         self.patients: list[str | int] = self.patients_df["id"].unique().tolist()
+         self.nb_patients: int = len(self.patients)
+         covariate_columns = self.patients_df.columns.to_list()
+         if "protocol_arm" not in covariate_columns:
+             self.patients_df["protocol_arm"] = "identity"
+
+         additional_columns: list[str] = self.patients_df.drop(
+             ["id", "protocol_arm"], axis=1
+         ).columns.tolist()
+
+         init_betas_list: list = []
+         if covariate_map is None:
+             print(
+                 f"No covariate map provided. All additional columns in `patients_df` will be handled as known descriptors: {additional_columns}"
+             )
+             self.covariate_map = None
+             self.covariate_names = []
+             self.covariate_coeffs_names = []
+             self.nb_covariates = 0
+             self.population_betas_names = self.PDU_names
+             init_betas_list = [val["mean"] for val in init_PDU.values()]
+             self.PDK_names = additional_columns
+             self.nb_PDK = len(self.PDK_names)
+         else:
+             self.covariate_map = covariate_map
+             self.population_betas_names: list = []
+             covariate_set = set()
+             covariate_coeffs_set = set()
+             pdk_names = set(additional_columns)
+             for PDU_name in self.PDU_names:
+                 self.population_betas_names.append(PDU_name)
+                 init_betas_list.append(init_PDU[PDU_name]["mean"])
+                 if PDU_name not in covariate_map:
+                     raise ValueError(
+                         f"No covariate map listed for {PDU_name}. Add an empty dict if it has no covariates."
+                     )
+                 for covariate, coef in self.covariate_map[PDU_name].items():
+                     if covariate not in additional_columns:
+                         raise ValueError(
+                             f"Covariate appears in the map but not in the patient set: {covariate}"
+                         )
+                     if covariate is not None:
+                         covariate_set.add(covariate)
+                         if covariate in pdk_names:
+                             pdk_names.remove(covariate)
+                         coef_name = coef["coef"]
+                         covariate_coeffs_set.add(coef_name)
+                         coef_val = coef["value"]
+                         self.population_betas_names.append(coef_name)
+                         init_betas_list.append(coef_val)
+             self.covariate_names = list(covariate_set)
+             self.covariate_coeffs_names = list(covariate_coeffs_set)
+             self.nb_covariates = len(self.covariate_names)
+             self.PDK_names = list(pdk_names)
+             self.nb_PDK = len(self.PDK_names)
+
+             print(f"Successfully loaded {self.nb_covariates} covariates:")
+             print(self.covariate_names)
+         if self.nb_PDK > 0:
+             self.patients_pdk = {}
+             for patient in self.patients:
+                 row = self.patients_df.loc[
+                     self.patients_df["id"] == patient
+                 ].drop_duplicates()
+                 self.patients_pdk.update(
+                     {
+                         patient: torch.as_tensor(
+                             row[self.PDK_names].values, device=device
+                         )
+                     }
+                 )
+             # Store the full PDK tensor on the device
+             self.patients_pdk_full = torch.cat(
+                 [self.patients_pdk[ind] for ind in self.patients]
+             ).to(device)
+             print(f"Successfully loaded {self.nb_PDK} known descriptors:")
+             print(self.PDK_names)
+         else:
+             # Create an empty PDK tensor
+             self.patients_pdk_full = torch.empty((self.nb_patients, 0), device=device)
+
+         if set(self.PDK_names + self.PDU_names + self.MI_names) != set(
+             self.structural_model.parameter_names
+         ):
+             raise ValueError(
+                 f"Non-matching descriptor set and structural model parameter set:\n{set(self.PDK_names + self.PDU_names + self.MI_names)}\n{set(self.structural_model.parameter_names)}"
+             )
+
+         self.descriptors: list[str] = self.PDK_names + self.PDU_names + self.MI_names
+         self.nb_descriptors: int = len(self.descriptors)
+         # Assume that the descriptors will always be provided to the model in the
+         # following order: PDK, PDU, MI
+         self.model_input_to_descriptor = torch.as_tensor(
+             [
+                 self.descriptors.index(param)
+                 for param in self.structural_model.parameter_names
+             ],
+             device=device,
+         ).long()
+         self.initial_betas = torch.as_tensor(init_betas_list, device=device)
+         self.nb_betas: int = len(self.population_betas_names)
+         self.outputs_names: list[str] = self.structural_model.output_names
+         self.nb_outputs: int = self.structural_model.nb_outputs
+         self.error_model_type: str = error_model_type
+         self.init_res_var = torch.as_tensor(init_res_var, device=device)
+         self.init_omega = torch.diag(
+             torch.as_tensor(
+                 [float(init_PDU[pdu]["sd"]) for pdu in self.PDU_names], device=device
+             )
+         )
+
+         # Assemble the list of design matrices from the covariate structure
+         self.design_matrices = self._create_all_design_matrices()
+         # Store the full design matrix on the device
+         self.full_design_matrix = torch.stack(
+             [self.design_matrices[p] for p in self.patients]
+         ).to(device)
+
+         # Initialize the NLME parameters
+         self.log_MI = self.initial_log_MI
+         self.population_betas = self.initial_betas
+         self.omega_pop = self.init_omega
+         self.omega_pop_lower_chol = torch.linalg.cholesky(self.omega_pop).to(device)
+         self.residual_var = self.init_res_var
+         self.eta_distribution = torch.distributions.MultivariateNormal(
+             loc=torch.zeros(self.nb_PDU, device=device),
+             covariance_matrix=self.omega_pop,
+         )
+         self.current_eta_samples = self.sample_individual_etas()
+         self.current_map_estimates = self.individual_parameters(
+             self.current_eta_samples
+         )
+
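For illustration, a minimal construction sketch, not part of the released wheel: `my_struct_model`, `k_el`, `V`, and `beta_V_weight` are hypothetical names, and the structural model's `parameter_names` are assumed to be exactly `V` and `k_el` (with `weight` consumed as a covariate).

    import numpy as np
    import pandas as pd

    # Hypothetical example: one MI parameter, one PDU with a weight covariate.
    patients = pd.DataFrame({"id": [1, 2], "weight": [70.0, 55.0]})
    model = NlmeModel(
        structural_model=my_struct_model,  # hypothetical StructuralModel instance
        patients_df=patients,
        init_log_MI={"k_el": float(np.log(0.1))},
        init_PDU={"V": {"mean": float(np.log(5.0)), "sd": 0.2}},
        init_res_var=[0.1],
        covariate_map={"V": {"weight": {"coef": "beta_V_weight", "value": 0.0}}},
        error_model_type="additive",
    )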
+     def _create_design_matrix(self, covariates: dict[str, float]) -> torch.Tensor:
+         """
+         Creates the design matrix X_i for a single individual based on the model's covariate map.
+         This matrix will be multiplied with the population betas so that log(theta_i[PDU]) = X_i @ betas + eta_i.
+         """
+         design_matrix_X_i = torch.zeros((self.nb_PDU, self.nb_betas), device=device)
+         col_idx = 0
+         for i, PDU_name in enumerate(self.PDU_names):
+             design_matrix_X_i[i, col_idx] = 1.0
+             col_idx += 1
+             if self.covariate_map is not None:
+                 for covariate in self.covariate_map[PDU_name].keys():
+                     design_matrix_X_i[i, col_idx] = float(covariates[covariate])
+                     col_idx += 1
+         return design_matrix_X_i
+
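For intuition, a hedged worked example of the layout this method produces; column order follows `self.PDU_names` and, within each PDU, the insertion order of its covariate map. All values are illustrative.

    import torch

    # Hypothetical setup: PDU_names = ["V", "CL"], where "V" has one covariate
    # ("weight") and "CL" has none. Betas are then ordered
    # [mu_V, beta_V_weight, mu_CL]. For a patient with weight 70:
    X_i = torch.tensor([[1.0, 70.0, 0.0],
                        [0.0, 0.0, 1.0]])
    betas = torch.tensor([1.6, 0.01, -0.7])  # illustrative population betas
    eta_i = torch.zeros(2)                   # no random effect in this sketch
    log_theta_pdu = X_i @ betas + eta_i      # log-scale PDU values for the patient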
+     def _create_all_design_matrices(self) -> dict[Union[str, int], torch.Tensor]:
+         """Creates a design matrix for each unique individual based on their covariates, as given in `patients_df`."""
+         design_matrices = {}
+         if self.nb_covariates == 0:
+             for ind_id in self.patients:
+                 design_matrices[ind_id] = self._create_design_matrix({})
+         else:
+             for ind_id in self.patients:
+                 individual_covariates = (
+                     self.patients_df[self.patients_df["id"] == ind_id]
+                     .iloc[0]
+                     .drop("id")
+                 )
+                 covariates_dict = individual_covariates.to_dict()
+                 design_matrices[ind_id] = self._create_design_matrix(covariates_dict)
+         return design_matrices
+
+     def add_observations(self, observations_df: pd.DataFrame) -> None:
+         """Associate the NLME model with a data frame of observations
+
+         Args:
+             observations_df (pd.DataFrame): A data frame of observations, with columns
+                 - `id`: the patient id. Should be consistent with self.patients_df
+                 - `time`: the observation time
+                 - `output_name`
+                 - `value`
+         """
+         # Store the raw data frame
+         self.observations_df = observations_df
+         # Data validation
+         input_columns = observations_df.columns.tolist()
+         if "id" not in input_columns:
+             raise ValueError(
+                 "Provided observation data frame should contain `id` column."
+             )
+         input_patients = observations_df["id"].unique()
+         if set(input_patients) != set(self.patients):
+             # Note: this check might be unnecessary
+             raise ValueError(
+                 f"Missing observations for the following patients: {set(self.patients) - set(input_patients)}"
+             )
+         if "time" not in input_columns:
+             raise ValueError(
+                 "Provided observation data frame should contain `time` column."
+             )
+         unique_outputs = observations_df["output_name"].unique().tolist()
+         if not (set(unique_outputs) <= set(self.outputs_names)):
+             raise ValueError(
+                 f"Unknown model output: {set(unique_outputs) - set(self.outputs_names)}"
+             )
+         if hasattr(self, "observations_tensors"):
+             print(
+                 "Warning: overwriting existing observation data frame for the NLME model"
+             )
+         if "value" not in input_columns:
+             raise ValueError(
+                 "The provided observations data frame does not contain a `value` column."
+             )
+         processed_df = observations_df[["id", "output_name", "time", "value"]].merge(
+             self.patients_df, how="left", on="id"
+         )
+         processed_df["task"] = processed_df.apply(
+             lambda r: r["output_name"] + "_" + r["protocol_arm"], axis=1
+         )
+         processed_df["task_index"] = processed_df["task"].apply(
+             lambda t: self.structural_model.tasks.index(t)
+         )
+         processed_df["output_index"] = processed_df["output_name"].apply(
+             lambda o: self.structural_model.output_names.index(o)
+         )
+         global_time_steps = (
+             processed_df["time"].drop_duplicates().sort_values().to_list()
+         )
+         processed_df["time_step_index"] = processed_df["time"].apply(
+             lambda t: global_time_steps.index(t)
+         )
+         self.global_time_steps = torch.as_tensor(global_time_steps, device=device)
+         self.nb_global_time_steps = self.global_time_steps.shape[0]
+         self.global_time_steps_expanded = (
+             self.global_time_steps.unsqueeze(0)
+             .unsqueeze(-1)
+             .repeat((self.nb_patients, 1, 1))
+         )
+         # Browse the observed data set and store relevant elements
+         self.observations_tensors: dict = {}
+         self.n_tot_observations_per_output = torch.zeros(self.nb_outputs, device=device)
+         for ind, patient in enumerate(self.patients):
+             this_patient = processed_df.loc[processed_df["id"] == patient]
+
+             tasks_indices_np = this_patient["task_index"].values
+             tasks_indices = torch.as_tensor(tasks_indices_np, device=device).long()
+
+             outputs_indices_np = this_patient["output_index"].values
+             outputs_indices = torch.as_tensor(outputs_indices_np, device=device).long()
+             # Add counts of observations to the total per output
+             self.n_tot_observations_per_output.scatter_add_(
+                 0,
+                 outputs_indices,
+                 torch.ones_like(outputs_indices, device=device, dtype=torch.float64),
+             )
+
+             outputs = torch.as_tensor(this_patient["value"].values, device=device)
+
+             time_steps = torch.as_tensor(this_patient["time"].values, device=device)
+             time_step_indices = torch.as_tensor(
+                 this_patient["time_step_index"].values, device=device
+             ).long()
+             p_index_repeated = torch.full(
+                 outputs.shape, ind, dtype=torch.int64, device=device
+             )
+
+             self.observations_tensors.update(
+                 {
+                     patient: {
+                         "observations": outputs,
+                         "time_steps": time_steps,
+                         "time_step_indices": time_step_indices,
+                         "tasks_indices": tasks_indices,
+                         "outputs_indices": outputs_indices,
+                         "p_index_repeated": p_index_repeated,
+                     }
+                 }
+             )
+
+         # Build the full data set tensors
+         self.full_obs_data = torch.cat(
+             [self.observations_tensors[p]["observations"] for p in self.patients]
+         ).to(device)
+         # Construct the indexing tensors
+         self.observation_to_patient_index = (
+             torch.cat(
+                 [
+                     self.observations_tensors[p]["p_index_repeated"]
+                     for p in self.patients
+                 ]
+             )
+             .long()
+             .to(device)
+         )
+         self.observation_to_timestep_index = (
+             torch.cat(
+                 [
+                     self.observations_tensors[p]["time_step_indices"]
+                     for p in self.patients
+                 ]
+             )
+             .long()
+             .to(device)
+         )
+         self.observation_to_task_index = (
+             torch.cat(
+                 [self.observations_tensors[p]["tasks_indices"] for p in self.patients]
+             )
+             .long()
+             .to(device)
+         )
+         # Construct an index tuple that maps the 3D tensor of predictions down to
+         # a 1D tensor of observed points
+         self.prediction_index = (
+             self.observation_to_patient_index,
+             self.observation_to_timestep_index,
+             self.observation_to_task_index,
+         )
+
+         self.full_output_indices = (
+             torch.cat(
+                 [self.observations_tensors[p]["outputs_indices"] for p in self.patients]
+             )
+             .long()
+             .to(device)
+         )
+         self.chunk_sizes: list[int] = [
+             self.observations_tensors[p]["observations"].cpu().shape[0]
+             for p in self.patients
+         ]
+
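An illustrative observations frame for the hypothetical model sketched earlier; `output_name` entries (here `conc`, an assumed name) must match the structural model's output names, and the pooled `time` values become `global_time_steps`.

    import pandas as pd

    obs = pd.DataFrame(
        {
            "id": [1, 1, 2],
            "time": [0.0, 24.0, 24.0],
            "output_name": ["conc", "conc", "conc"],
            "value": [1.2, 0.4, 0.6],
        }
    )
    model.add_observations(obs)  # `model` as constructed in the earlier sketch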
+     def update_omega(self, omega: torch.Tensor) -> None:
+         """Update the covariance matrix of the NLME model."""
+         assert (
+             self.omega_pop.shape == omega.shape
+         ), f"Wrong omega shape: {omega.shape}, expected: {self.omega_pop.shape}"
+         self.omega_pop = omega
+         self.omega_pop_lower_chol = torch.linalg.cholesky(self.omega_pop).to(device)
+         self.eta_distribution = torch.distributions.MultivariateNormal(
+             loc=torch.zeros(self.nb_PDU, device=device),
+             covariance_matrix=self.omega_pop,
+         )
+
+     def update_res_var(self, residual_var: torch.Tensor) -> None:
+         """Update the residual variance of the NLME model, while ensuring it remains positive."""
+         assert (
+             self.residual_var.shape == residual_var.shape
+         ), f"Wrong res var shape: {residual_var.shape}, expected: {self.residual_var.shape}"
+         self.residual_var = residual_var.clamp(min=1e-6)
+
+     def update_betas(self, betas: torch.Tensor) -> None:
+         """Update the betas of the NLME model."""
+         assert (
+             self.population_betas.shape == betas.shape
+         ), f"Wrong beta shape: {betas.shape}, expected: {self.population_betas.shape}"
+         self.population_betas = betas
+
+     def update_log_mi(self, log_MI: torch.Tensor) -> None:
+         """Update the model intrinsic parameter values of the NLME model."""
+         assert (
+             self.log_MI.shape == log_MI.shape
+         ), f"Wrong MI shape: {log_MI.shape}, expected: {self.log_MI.shape}"
+         self.log_MI = log_MI
+
+     def update_eta_samples(self, eta: torch.Tensor) -> None:
+         """Update the model's current individual random-effect samples."""
+         assert (
+             self.current_eta_samples.shape == eta.shape
+         ), f"Wrong individual samples shape: {eta.shape}, expected: {self.current_eta_samples.shape}"
+         self.current_eta_samples = eta
+
+     def update_map_estimates(self, theta: torch.Tensor) -> None:
+         """Update the model's current maximum a posteriori estimates."""
+         assert (
+             self.current_map_estimates.shape == theta.shape
+         ), f"Wrong individual parameters shape: {theta.shape}, expected: {self.current_map_estimates.shape}"
+         self.current_map_estimates = theta
+
+     def sample_individual_etas(self) -> torch.Tensor:
+         """Sample individual random effects from the current estimate of Omega
+
+         Returns:
+             torch.Tensor (size nb_patients x nb_PDU): individual random effects for all patients in the population
+         """
+         etas_dist = self.eta_distribution.expand([self.nb_patients])
+         etas = etas_dist.sample()
+         return etas
+
+     def individual_parameters(
+         self,
+         individual_etas: torch.Tensor,
+     ) -> torch.Tensor:
+         """Compute individual patient parameters
+
+         Transforms log(MI) (model intrinsic parameters), the betas (log(mu)s and covariate coefficients), and the individual random effects (etas) into individual parameters (theta_i), using each patient's design matrix.
+         Assumes a log-normal distribution for individual parameters and covariate effects: theta_i[PDU] = mu_pop * exp(eta_i) * exp(covariates_i * cov_coeffs), where eta_i is drawn from N(0, Omega), and theta_i[MI] = MI.
+
+         Args:
+             individual_etas (torch.Tensor): one set of sampled random effects for each patient
+         Returns:
+             torch.Tensor [nb_patients x nb_parameters]: One parameter set for each patient. Dim 0 corresponds to the patients, dim 1 to the parameters
+         """
+         # Compute the individual PDUs
+         log_thetas_PDU = (
+             self.full_design_matrix @ self.population_betas + individual_etas
+         )
+         # Gather the MI values, and expand them (same for each patient)
+         log_MI_expanded = self.log_MI.unsqueeze(0).repeat(self.nb_patients, 1)
+
+         # List the PDK values for each patient and assemble them into a tensor.
+         # This step is crucial: the parameters must be stored in the order
+         # PDK, PDU, MI
+         thetas = torch.cat(
+             (
+                 self.patients_pdk_full,
+                 torch.exp(torch.cat((log_thetas_PDU, log_MI_expanded), dim=1)),
+             ),
+             dim=1,
+         )
+         return thetas
+
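A short usage sketch of the sampling-to-parameters path, under the same hypothetical setup as the earlier sketches.

    etas = model.sample_individual_etas()       # [nb_patients, nb_PDU]
    thetas = model.individual_parameters(etas)  # [nb_patients, nb_descriptors]
    # Column order: PDK values as-is, then exp(X_i @ betas + eta_i), then exp(log_MI).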
+     def struc_model_inputs_from_theta(self, thetas: torch.Tensor) -> torch.Tensor:
+         """Return model inputs for all patients
+
+         Args:
+             thetas (torch.Tensor): Parameter values per patient (one row per patient)
+
+         Returns:
+             torch.Tensor: the full inputs required to simulate all patients on all time steps
+         """
+         if not hasattr(self, "observations_tensors"):
+             raise ValueError(
+                 "Cannot compute patient predictions without an associated observations data frame."
+             )
+         # Order the columns of theta, and add a repeat dimension to cover time steps
+         theta_expanded = (
+             thetas[:, self.model_input_to_descriptor]
+             .unsqueeze(1)
+             .repeat((1, self.nb_global_time_steps, 1))
+         )
+         full_inputs = torch.cat(
+             (
+                 theta_expanded,
+                 self.global_time_steps_expanded,
+             ),
+             dim=2,
+         )
+         assert full_inputs.shape == (
+             self.nb_patients,
+             self.nb_global_time_steps,
+             self.nb_descriptors + 1,
+         )
+         return full_inputs
+
+     def predict_outputs_from_theta(
+         self, thetas: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Return model predictions for all patients
+
+         Args:
+             thetas (torch.Tensor): Parameter values per patient (one row per patient)
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor]: predicted means and predictive variances for all observations
+         """
+         model_inputs = self.struc_model_inputs_from_theta(thetas)
+         # shape: [nb_patients, nb_time_steps, nb_params + 1]
+         pred_mean, pred_var = self.structural_model.simulate(
+             model_inputs,
+             self.prediction_index,
+             self.chunk_sizes,
+         )
+         return pred_mean, pred_var
+
+     def add_residual_error(self, outputs: torch.Tensor) -> torch.Tensor:
+         """Add residual noise to predicted outputs according to the error model."""
+         res_var = self.residual_var.index_select(0, self.full_output_indices)
+         # Normal expects a standard deviation, so draw with the sqrt of the variance
+         noise = torch.distributions.Normal(
+             torch.zeros(self.full_output_indices.shape[0], device=device),
+             res_var.sqrt(),
+         ).sample()
+         if self.error_model_type == "additive":
+             new_out = outputs + noise
+         elif self.error_model_type == "proportional":
+             # Proportional error: y = f * (1 + eps), matching calculate_residuals
+             new_out = outputs * (1 + noise)
+         else:
+             raise ValueError(f"Non-implemented error model {self.error_model_type}")
+         return new_out
+
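A standalone sketch of the two error models as implemented above, with hypothetical values.

    import torch

    pred = torch.tensor([1.0, 2.0, 4.0])    # predicted outputs
    res_sd = torch.tensor([0.1, 0.1, 0.1])  # sqrt of residual variance per point
    eps = torch.randn(3) * res_sd
    y_additive = pred + eps                 # additive: y = f + eps
    y_proportional = pred * (1 + eps)       # proportional: y = f * (1 + eps)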
+     def outputs_to_df(self, outputs: torch.Tensor) -> pd.DataFrame:
+         """Transform the NLME model outputs to a data frame in order to compare with observed data
+
+         Args:
+             outputs (torch.Tensor): Outputs from `self.predict_outputs_from_theta`
+
+         Returns:
+             pd.DataFrame: A data frame containing the following columns
+                 - `id`
+                 - `output_name`
+                 - `protocol_arm`
+                 - `time`
+                 - `predicted_value`
+         """
+         outputs_list = outputs.cpu().split(self.chunk_sizes)
+         df_list = []
+         for ind_idx, ind in enumerate(self.patients):
+             time_steps = self.observations_tensors[ind]["time_steps"]
+             task_list = self.observations_tensors[ind]["tasks_indices"]
+             temp_df = pd.DataFrame(
+                 {
+                     "time": time_steps.cpu().numpy(),
+                     "id": ind,
+                     "task_index": task_list.cpu(),
+                     "predicted_value": outputs_list[ind_idx].numpy(),
+                 }
+             )
+             temp_df["output_name"] = temp_df["task_index"].apply(
+                 lambda t: self.outputs_names[
+                     self.structural_model.task_idx_to_output_idx[t]
+                 ]
+             )
+             temp_df["protocol_arm"] = temp_df["task_index"].apply(
+                 lambda t: self.structural_model.task_idx_to_protocol[t]
+             )
+             df_list.append(temp_df)
+         out_df = pd.concat(df_list)
+         out_df = out_df.drop(columns=["task_index"])
+         return out_df
+
+     def _log_prior_etas(self, etas: torch.Tensor) -> torch.Tensor:
+         """Compute log-prior of random effect samples (etas)
+
+         Args:
+             etas (torch.Tensor): Individual samples, assuming eta_i ~ N(0, Omega)
+
+         Returns:
+             torch.Tensor [nb_patients]: Values of the log-prior, computed according to:
+
+                 P(eta) = (1/sqrt((2pi)^k * |Omega|)) * exp(-0.5 * eta.T * Omega^-1 * eta)
+                 log P(eta) = -0.5 * (k * log(2pi) + log|Omega| + eta.T * Omega^-1 * eta)
+         """
+         log_priors: torch.Tensor = self.eta_distribution.log_prob(etas).to(device)
+         return log_priors
+
+     def log_posterior_etas(
+         self,
+         etas: torch.Tensor,
+     ) -> tuple[
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         list,
+     ]:
+         """Compute the log-posterior of a set of random effects
+
+         Args:
+             etas (torch.Tensor): Random effects samples
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list]:
+                 - log-posterior likelihood of the etas
+                 - current thetas (individual parameters)
+                 - log values of the current PDU estimates (useful for SAEM)
+                 - predicted values for all observations
+                 - parameter sets of patients flagged for high predictive variance
+         """
+         if not hasattr(self, "observations_tensors"):
+             raise ValueError(
+                 "Cannot compute log-posterior without an associated observations data frame."
+             )
+         # Get individual parameters in a tensor
+         individual_params: torch.Tensor = self.individual_parameters(
+             individual_etas=etas,
+         )
+         # Run the surrogate model
+         full_pred, full_var = self.predict_outputs_from_theta(individual_params)
+         var_list = full_var.split(self.chunk_sizes)
+         # Validate the variance magnitude
+         warnings = self.variance_level_check(var_list, self.pred_var_threshold)
+         flagged_patients = [individual_params[i, :].tolist() for i in warnings]
+
+         # Calculate the log-prior of the random samples
+         log_priors: torch.Tensor = self._log_prior_etas(etas)
+
+         # Group by individual and calculate the log-likelihood for each
+         log_likelihood_observations = self.log_likelihood_observation(full_pred)
+
+         log_posterior = log_likelihood_observations + log_priors
+         current_log_pdu = torch.log(
+             individual_params[:, self.nb_PDK : self.nb_PDK + self.nb_PDU]
+         )
+         return (
+             log_posterior,
+             individual_params,
+             current_log_pdu,
+             full_pred,
+             flagged_patients,
+         )
+
653
+
654
+ def calculate_residuals(
655
+ self, observed_data: torch.Tensor, predictions: torch.Tensor
656
+ ) -> torch.Tensor:
657
+ """Calculates residuals based on the error model for a single patient
658
+
659
+ Args:
660
+ observed_data: Tensor of observations
661
+ predictions: Tensor of predictions
662
+
663
+ Returns:
664
+ torch.Tensor: a tensor of residual values
665
+ """
666
+ if self.error_model_type == "additive":
667
+ return observed_data - predictions
668
+ elif self.error_model_type == "proportional":
669
+ return (observed_data - predictions) / predictions
670
+ else:
671
+ raise ValueError("Unsupported error model type.")
672
+
+     def sum_sq_residuals(self, prediction: torch.Tensor) -> torch.Tensor:
+         """Sum the squared residuals per model output."""
+         sq_residuals = torch.square(
+             self.calculate_residuals(self.full_obs_data, prediction)
+         )
+         sum_residuals = torch.zeros(self.nb_outputs, device=device)
+         sum_residuals.scatter_add_(0, self.full_output_indices, sq_residuals)
+         return sum_residuals
+
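For readers unfamiliar with `scatter_add_`, a minimal sketch of the per-output reduction used above.

    import torch

    sq_res = torch.tensor([1.0, 4.0, 9.0, 16.0])  # squared residuals
    out_idx = torch.tensor([0, 1, 0, 1])          # output index per observation
    sums = torch.zeros(2)
    sums.scatter_add_(0, out_idx, sq_res)         # -> tensor([10., 20.])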
+     def log_likelihood_observation(
+         self,
+         predictions: torch.Tensor,
+     ) -> torch.Tensor:
+         """Calculate the per-patient log-likelihood of the observations given the predictions.
+
+         Assumes errors follow N(0, sqrt(residual_var)). Observations are taken from
+         `self.full_obs_data` and the per-output residual variances from `self.residual_var`.
+
+         Args:
+             predictions: torch.Tensor of predictions, organized like `self.full_obs_data`
+
+         Returns:
+             torch.Tensor [nb_patients]: summed log-likelihood for each patient
+         """
+         residuals: torch.Tensor = self.calculate_residuals(
+             self.full_obs_data, predictions
+         )
+         res_error_var = self.residual_var.index_select(0, self.full_output_indices)
+         # Log-likelihood of the normal distribution
+         if self.error_model_type == "additive":
+             log_lik_full = -0.5 * (
+                 torch.log(2 * torch.pi * res_error_var) + (residuals**2 / res_error_var)
+             )
+         elif self.error_model_type == "proportional":
+             # Raw-scale variance is res_error_var * predictions**2; the residuals
+             # are already scaled by the predictions, so they are normalized by
+             # res_error_var alone.
+             variance = res_error_var * torch.square(predictions)
+             log_lik_full = -0.5 * (
+                 torch.log(2 * torch.pi * variance) + (residuals**2 / res_error_var)
+             )
+         else:
+             raise ValueError("Unsupported error model type.")
+         log_lik_per_patient = torch.zeros(self.nb_patients, device=device)
+         log_lik_per_patient.scatter_add_(
+             0, self.observation_to_patient_index, log_lik_full
+         )
+         return log_lik_per_patient
+
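The closed form above matches the Gaussian log-density; a quick sanity sketch against `torch.distributions.Normal`.

    import torch

    res = torch.tensor([0.5])   # residual
    var = torch.tensor([0.04])  # residual variance
    closed_form = -0.5 * (torch.log(2 * torch.pi * var) + res**2 / var)
    via_dist = torch.distributions.Normal(0.0, var.sqrt()).log_prob(res)
    assert torch.allclose(closed_form, via_dist)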
+     def mh_step(
+         self,
+         current_etas: torch.Tensor,
+         current_log_prob: torch.Tensor,
+         current_pred: torch.Tensor,
+         current_pdu: torch.Tensor,
+         current_thetas: torch.Tensor,
+         step_size: float,
+         learning_rate: float,
+         target_acceptance_rate: float = 0.234,
+         verbose: bool = False,
+     ) -> tuple[
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         float,
+         list,
+     ]:
+         """Perform one step of a Metropolis-Hastings transition kernel
+
+         Args:
+             current_etas (torch.Tensor): values of the individual random effects for all patients
+             current_log_prob (torch.Tensor): log-posterior likelihood of the current random effects
+             current_pred (torch.Tensor): model predictions associated with the current random effects
+             current_pdu (torch.Tensor): current log PDU values
+             current_thetas (torch.Tensor): current individual parameter sets
+             step_size (float): current value of the MH step size
+             learning_rate (float): current learning rate (defined by the optimization algorithm)
+             target_acceptance_rate (float, optional): Target for the MCMC acceptance rate. Defaults to 0.234 [1].
+             verbose (bool, optional): print the acceptance rate. Defaults to False.
+
+         [1] Sherlock C. Optimal Scaling of the Random Walk Metropolis: General Criteria for the 0.234 Acceptance Rule. Journal of Applied Probability. 2013;50(1):1-15. doi:10.1239/jap/1363784420
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, list]:
+                 - updated individual random effects
+                 - updated log-posterior likelihood
+                 - -2 times the summed updated log-posterior (complete likelihood)
+                 - updated predictions, for each patient of the observation data set
+                 - updated thetas
+                 - updated values of the log PDUs
+                 - updated step size
+                 - parameter sets of patients whose predictive variance exceeds the threshold
+         """
+         proposal_noise = (
+             torch.randn_like(current_etas, device=device) @ self.omega_pop_lower_chol
+         )
+         proposal_etas = current_etas + step_size * proposal_noise
+         (
+             proposal_log_prob,
+             proposal_theta,
+             proposal_log_pdus,
+             proposal_pred,
+             warnings,
+         ) = self.log_posterior_etas(proposal_etas)
+         deltas: torch.Tensor = proposal_log_prob - current_log_prob
+         log_u: torch.Tensor = torch.log(torch.rand_like(deltas, device=device))
+         accept_mask: torch.Tensor = log_u < deltas
+         accept_mask_parameters = accept_mask.unsqueeze(-1).expand(
+             -1, current_etas.shape[1]
+         )
+         accept_mask_predictions = accept_mask.index_select(
+             0, self.observation_to_patient_index
+         )
+         new_etas = torch.where(accept_mask_parameters, proposal_etas, current_etas).to(
+             device
+         )
+         new_log_pdus = torch.where(
+             accept_mask_parameters, proposal_log_pdus, current_pdu
+         ).to(device)
+         new_log_prob = torch.where(accept_mask, proposal_log_prob, current_log_prob).to(
+             device
+         )
+         new_complete_likelihood = -2 * new_log_prob.sum(dim=0)
+         new_pred = torch.where(accept_mask_predictions, proposal_pred, current_pred).to(
+             device
+         )
+         new_acceptance_rate: float = accept_mask.cpu().float().mean().item()
+         if verbose:
+             print(f"  Acceptance rate: {new_acceptance_rate:.2f}")
+         new_step_size: float = step_size * np.exp(
+             learning_rate * (new_acceptance_rate - target_acceptance_rate)
+         )
+         new_theta = torch.where(
+             accept_mask.unsqueeze(1).expand(-1, current_thetas.shape[1]),
+             proposal_theta,
+             current_thetas,
+         ).to(device)
+         return (
+             new_etas,
+             new_log_prob,
+             new_complete_likelihood,
+             new_pred,
+             new_theta,
+             new_log_pdus,
+             new_step_size,
+             warnings,
+         )
+
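A hedged sketch of how an optimizer loop might drive this kernel; the real driver (PySAEM) is not part of this file, and the iteration count, initial `step_size`, and `lr` below are illustrative.

    etas = model.current_eta_samples
    log_prob, thetas, log_pdu, pred, _ = model.log_posterior_etas(etas)
    step_size, lr, n_iter = 0.1, 0.5, 100
    for _ in range(n_iter):
        (etas, log_prob, _, pred, thetas,
         log_pdu, step_size, warn) = model.mh_step(
            etas, log_prob, pred, log_pdu, thetas, step_size, lr
        )
    model.update_eta_samples(etas)
    model.update_map_estimates(thetas)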
+     def map_estimates_descriptors(self) -> pd.DataFrame:
+         """Return the current MAP parameter estimates as a data frame (one row per patient)."""
+         theta = self.current_map_estimates
+         if theta is None:
+             raise ValueError("No estimation available yet. Run the algorithm first.")
+
+         map_per_patient = pd.DataFrame(
+             data=theta.cpu().numpy(), columns=self.descriptors
+         )
+         return map_per_patient
+
+     def map_estimates_predictions(self) -> pd.DataFrame:
+         """Return the model predictions at the current MAP estimates as a data frame."""
+         theta = self.current_map_estimates
+         if theta is None:
+             raise ValueError(
+                 "No estimation available yet. Run the optimization algorithm first."
+             )
+         simulated_tensor, _ = self.predict_outputs_from_theta(theta)
+         simulated_df = self.outputs_to_df(simulated_tensor)
+         return simulated_df
+
+     def variance_level_check(
+         self, var_list: tuple[torch.Tensor, ...], threshold: float
+     ) -> list:
+         """Return the indices of patients whose maximum predictive variance exceeds the threshold."""
+         warnings = []
+         for i in range(self.nb_patients):
+             var, _ = var_list[i].cpu().max(0)
+             if var > threshold:
+                 warnings.append(i)
+         return warnings