vpop-calibration 2.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,945 @@
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ from scipy.optimize import minimize
4
+ from tqdm.notebook import tqdm
5
+ from typing import Union, Optional
6
+ from pandas import DataFrame
7
+ import pandas as pd
8
+ import numpy as np
9
+ from IPython.display import display
10
+
11
+ from .utils import smoke_test, device
12
+ from .nlme import NlmeModel
13
+ from .structural_model import StructuralGp
14
+
15
+
16
+ # Main SAEM Algorithm Class
17
+ class PySaem:
18
+ def __init__(
19
+ self,
20
+ model: NlmeModel,
21
+ observations_df: DataFrame,
22
+ # MCMC parameters for the E-step
23
+ mcmc_first_burn_in: int = 5,
24
+ mcmc_nb_transitions: int = 1,
25
+ nb_phase1_iterations: int = 100,
26
+ nb_phase2_iterations: Union[int, None] = None,
27
+ convergence_threshold: float = 1e-4,
28
+ patience: int = 5,
29
+ learning_rate_power: float = 0.8,
30
+ annealing_factor: float = 0.95,
31
+ init_step_size: float = 0.5, # stick to the 0.1 - 1 range
32
+ verbose: bool = False,
33
+ optim_max_fun: int = 50,
34
+ live_plot: bool = True,
35
+ plot_frames: int = 20,
36
+ plot_columns: int = 3,
37
+ plot_indiv_figsize: tuple[float, float] = (3.0, 1.2),
38
+ true_log_MI: Optional[dict[str, float]] = None,
39
+ true_log_PDU: Optional[dict[str, dict[str, float | bool]]] = None,
40
+ true_res_var: Optional[list[float]] = None,
41
+ true_covariates: Optional[dict[str, dict[str, dict[str, str | float]]]] = None,
42
+ ):
43
+ """Instantiate an SAEM optimizer for an NLME model
44
+
45
+ Args:
46
+ model (NlmeModel): The model to be optimized
47
+ observations_df (DataFrame): The data set containing observations
48
+ mcmc_first_burn_in (int, optional): Number of discarded samples in the first iteration. Defaults to 5.
49
+ mcmc_nb_transitions (int, optional): Number of kernel transitions computed at each iteration. Defaults to 1.
50
+ nb_phase1_iterations (int, optional): Number of iterations in the exploration phase. Defaults to 100.
51
+ nb_phase2_iterations (Union[int, None], optional): Number of iterations in the smoothing (convergence) phase. Defaults to None, in which case it equals nb_phase1_iterations.
52
+ convergence_threshold (float, optional): Estimated parameter relative change threshold considered for convergence. Defaults to 1e-4.
53
+ patience (int, optional): Number of iterations of consecutive low relative change considered for early stopping of the algorithm. Defaults to 5.
54
+ learning_rate_power (float, optional): Decay exponent of the M-step stochastic-approximation learning rate, 1 / k^learning_rate_power during the smoothing phase. Defaults to 0.8.
55
+ annealing_factor (float, optional): Exploration phase annealing factor for residual and parameter variance. Defaults to 0.95.
56
+ init_step_size (float, optional): Initial MCMC step size scaling factor. Defaults to 0.5.
57
+ optim_max_fun (int, optional): Maximum number of function evaluations passed to scipy.optimize.minimize (used to calibrate the model-intrinsic parameters). Defaults to 50.
58
+ verbose (bool): Print various info during iterations. Defaults to False.
59
+ live_plot (bool): Print and update a plot of parameters during iterations. Defaults to True.
60
+ plot_frames (int): Number of iterations between live plot updates; lower values refresh more often but slow the run down. Defaults to 20.
61
+ plot_columns (int): Number of columns to display the convergence plot. Defaults to 3.
62
+ plot_indiv_figsize (tuple[float, float]): Size (width, height) of each individual panel in the convergence plot. Defaults to (3.0, 1.2).
+ true_log_MI (dict[str, float], optional): Reference (e.g. simulated ground-truth) log values of the model-intrinsic parameters, drawn as dashed lines on the convergence plot. Defaults to None.
+ true_log_PDU (dict[str, dict[str, float | bool]], optional): Reference population distribution of the log PDUs (entries such as "mean" and "sd"), drawn as dashed lines on the convergence plot. Defaults to None.
+ true_res_var (list[float], optional): Reference residual variances, one per model output, drawn as dashed lines on the convergence plot. Defaults to None.
+ true_covariates (dict, optional): Reference covariate effects; each entry provides the coefficient name ("coef") and its value ("value"), drawn as dashed lines on the convergence plot. Defaults to None.
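+
+ Example (illustrative sketch; assumes `model` is an already configured NlmeModel and `obs_df` an observations DataFrame in the expected long format):
+
+     saem = PySaem(model=model, observations_df=obs_df, live_plot=False)
+     saem.run()
+     history = saem.history  # per-iteration parameter estimates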
63
+ """
64
+
65
+ self.model: NlmeModel = model
66
+ self.model.add_observations(observations_df)
67
+ # MCMC sampling in the E-step parameters
68
+ self.mcmc_first_burn_in: int = mcmc_first_burn_in
69
+ self.mcmc_nb_transitions: int = mcmc_nb_transitions
70
+ # SAEM iteration parameters
71
+ # phase 1 = exploratory: M-step learning rate = 1 and simulated annealing on
72
+ # phase 2 = smoothing: M-step learning rate = 1 / phase2_iter^learning_rate_power
73
+ if smoke_test:
74
+ self.nb_phase1_iterations = 1
75
+ self.nb_phase2_iterations = 2
76
+ else:
77
+ self.nb_phase1_iterations: int = nb_phase1_iterations
78
+ self.nb_phase2_iterations: int = (
79
+ nb_phase2_iterations
80
+ if nb_phase2_iterations is not None
81
+ else nb_phase1_iterations
82
+ )
83
+ self.current_phase = 1
84
+
85
+ # convergence parameters
86
+ self.convergence_threshold: float = convergence_threshold
87
+ self.patience: int = patience
88
+ self.consecutive_converged_iters: int = 0
89
+
90
+ # Numerical parameters that depend on the iterations phase
91
+ # The learning rate for the step-size adaptation in E-step sampling
92
+ self.step_size: float = init_step_size / np.sqrt(self.model.nb_PDU)
93
+ self.init_step_size_adaptation: float = 0.5
94
+ self.step_size_learning_rate_power: float = 0.5
95
+
96
+ # The learning rate for the stochastic approximation in the M-step
97
+ self.learning_rate_m_step: float = 1.0
98
+ self.learning_rate_power: float = learning_rate_power
99
+ self.annealing_factor: float = annealing_factor
100
+
101
+ # Initialize the learning rate and step size adaptation rate
102
+ self.learning_rate_m_step, self.step_size_adaptation = (
103
+ self._compute_learning_rates(0)
104
+ )
105
+
106
+ self.verbose = verbose
107
+ if smoke_test:
108
+ self.optim_max_fun = 1
109
+ else:
110
+ self.optim_max_fun = optim_max_fun
111
+ self.live_plot = live_plot
112
+ self.plot_frames = plot_frames
113
+ self.plot_columns = plot_columns
114
+ self.plot_indiv_figsize = plot_indiv_figsize
115
+
116
+ # Initialize the random effects to 0
117
+ self.current_etas = self.model.current_eta_samples
118
+
119
+ # Initialize current estimation of patient parameters from the 0 random effects
120
+ (
121
+ self.current_log_prob,
122
+ self.current_thetas,
123
+ self.current_log_pdu,
124
+ self.current_pred,
125
+ flagged_patients,
126
+ ) = self.model.log_posterior_etas(self.current_etas)
127
+ self.current_complete_likelihood = torch.exp(
128
+ self.current_log_prob.sum(dim=0)
129
+ ).to(device)
130
+
131
+ # Initialize the optimizer history
132
+ self._init_history(
133
+ self.model.population_betas,
134
+ self.model.omega_pop,
135
+ self.model.log_MI,
136
+ self.model.residual_var,
137
+ self.current_complete_likelihood,
138
+ flagged_patients,
139
+ )
140
+ self.current_iteration: int = 0
141
+
142
+ # Initialize the values for convergence checks
143
+ self.prev_params: dict[str, torch.Tensor] = {
144
+ "log_MI": self.model.log_MI,
145
+ "population_betas": self.model.population_betas,
146
+ "population_omega": self.model.omega_pop,
147
+ "residual_error_var": self.model.residual_var,
148
+ }
149
+
150
+ # pre-compute full design matrix once
151
+ self.X = torch.stack(
152
+ [self.model.design_matrices[ind] for ind in self.model.patients],
153
+ dim=0,
154
+ ).to(device)
155
+ # Precompute the gram matrix
156
+ self.sufficient_stat_gram_matrix = (
157
+ torch.matmul(self.X.transpose(1, 2), self.X).sum(dim=0).to(device)
158
+ )
159
+
160
+ # Initialize sufficient statistics
161
+ self.sufficient_stat_cross_product = (
162
+ (self.X.transpose(1, 2) @ self.current_log_pdu.unsqueeze(-1))
163
+ .sum(dim=0)
164
+ .to(device)
165
+ )
166
+ self.sufficient_stat_outer_product = torch.matmul(
167
+ self.current_log_pdu.transpose(0, 1), self.current_log_pdu
168
+ ).to(device)
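+ # For reference, the three precomputed quantities above are the SAEM sufficient statistics
+ # of the linear mixed-effects layer (notation assumed for illustration, not from the source):
+ #   gram matrix    = sum_i X_i^T X_i              (fixed across iterations)
+ #   cross product  = sum_i X_i^T log_pdu_i        (updated by stochastic approximation)
+ #   outer product  = sum_i log_pdu_i log_pdu_i^T  (updated by stochastic approximation)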
169
+
170
+ self.true_log_MI = true_log_MI
171
+ self.true_log_PDUs = true_log_PDU
172
+ if true_covariates is not None:
173
+ self.true_cov = {
174
+ str(cov["coef"]): float(cov["value"])
175
+ for item in true_covariates.values()
176
+ for cov in item.values()
177
+ }
178
+ if true_res_var is not None:
179
+ self.true_res_var = {
180
+ self.model.outputs_names[k]: val for k, val in enumerate(true_res_var)
181
+ }
182
+
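+ # When the structural model is a GP emulator, record the (log-scale) training range of each
+ # PDU and MI parameter so the convergence plots can shade the region the GP was trained on.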
183
+ if isinstance(self.model.structural_model, StructuralGp):
184
+ self.training_ranges = {}
185
+ training_samples = np.log(
186
+ self.model.structural_model.gp_model.data.full_df_raw[
187
+ self.model.PDU_names + self.model.MI_names
188
+ ]
189
+ )
190
+ train_min = training_samples.min(axis=0)
191
+ train_max = training_samples.max(axis=0)
192
+ for param in self.model.PDU_names + self.model.MI_names:
193
+ self.training_ranges.update(
194
+ {param: {"low": train_min[param], "high": train_max[param]}}
195
+ )
196
+
197
+ def m_step_update(
198
+ self,
199
+ log_pdu: torch.Tensor,
200
+ s_cross_product: torch.Tensor,
201
+ s_outer_product: torch.Tensor,
202
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
203
+ """Perform the M-step update
204
+
205
+ Args:
206
+ log_pdu (torch.Tensor): Current estimation of the log-scaled parameters
207
+ s_cross_product (torch.Tensor): Current sufficient statistics 1 - cross product
208
+ s_outer_product (torch.Tensor): Current sufficient statistics 2 - outer product
209
+
210
+ Returns:
211
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: Updated value for
212
+ - sufficient statistics: cross product
213
+ - sufficient statistics: outer product
214
+ - beta parameters
215
+ - omega matrix
216
+ """
217
+
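+ # In effect this is the closed-form M-step of the linear mixed-effects layer (sketched here for reference):
+ #   beta  solves (sum_i X_i^T X_i) beta = S_cross            (normal equations via torch.linalg.solve)
+ #   Omega = (S_outer - sum_i (X_i beta)(X_i beta)^T) / nb_patients, then projected onto SPD matrices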
218
+ assert log_pdu.shape[0] == self.X.shape[0]
219
+ cross_product = (
220
+ (self.X.transpose(1, 2) @ log_pdu.unsqueeze(-1)).sum(dim=0).to(device)
221
+ )
222
+ new_s_cross_product = self._stochastic_approximation(
223
+ s_cross_product, cross_product
224
+ )
225
+ outer_product = torch.matmul(log_pdu.transpose(0, 1), log_pdu).to(device)
226
+ new_s_outer_product = self._stochastic_approximation(
227
+ s_outer_product, outer_product
228
+ )
229
+
230
+ new_beta = torch.linalg.solve(
231
+ self.sufficient_stat_gram_matrix, new_s_cross_product
232
+ ).to(device)
233
+
234
+ new_log_pdu = torch.matmul(self.X, new_beta.unsqueeze(0)).squeeze(-1).to(device)
235
+ # Propose a new value for omega
236
+ new_omega = (
237
+ 1
238
+ / self.model.nb_patients
239
+ * (
240
+ new_s_outer_product
241
+ - torch.matmul(new_log_pdu.transpose(0, 1), new_log_pdu)
242
+ )
243
+ ).to(device)
244
+ new_omega = self._clamp_eigen_values(new_omega)
245
+
246
+ return (
247
+ new_s_cross_product,
248
+ new_s_outer_product,
249
+ new_beta.squeeze(-1),
250
+ new_omega,
251
+ )
252
+
253
+ def _check_convergence(self, new_params: dict[str, torch.Tensor]) -> bool:
254
+ """Checks for convergence based on the relative change in parameters."""
255
+ all_converged = True
256
+ for name, current_val in new_params.items():
257
+ if current_val.shape[0] > 0:
258
+ prev_val = self.prev_params[name]
259
+ abs_diff = torch.abs(current_val - prev_val)
260
+ abs_sum = torch.abs(current_val) + torch.abs(prev_val) + 1e-9
261
+ relative_change = abs_diff / abs_sum
262
+ if torch.any(relative_change > self.convergence_threshold):
263
+ all_converged = False
264
+ break
265
+ return all_converged
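+
+ # Worked example (values assumed for illustration): prev = 2.00 and current = 2.01 give
+ # |2.01 - 2.00| / (|2.01| + |2.00| + 1e-9) ~ 2.5e-3, above the default threshold of 1e-4,
+ # so that parameter does not yet count as converged.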
266
+
267
+ def _compute_learning_rates(self, iteration: int) -> tuple[float, float]:
268
+ """
269
+ Calculate the SAEM M-step learning rate (alpha_k) and the Metropolis-Hastings step-size adaptation rate (gamma_k).
270
+
271
+ Phase 1:
272
+ alpha_k = 1 (exploration)
273
+ gamma_k = c_0 / max(1, k)^0.5 , with c_0 = init_step_size_adaptation
274
+ Phase 2:
275
+ alpha_k = 1 / (k - nb_phase1_iterations + 1)^learning_rate_power (the base is the iteration index within phase 2)
276
+ gamma_k = 0
277
+ """
278
+ if iteration < self.nb_phase1_iterations:
279
+ learning_rate_m_step = 1.0
280
+ learning_rate_e_step = self.init_step_size_adaptation / (
281
+ np.maximum(1, iteration) ** 0.5
282
+ )
283
+ else:
284
+ learning_rate_m_step = 1.0 / (
285
+ (iteration - self.nb_phase1_iterations + 1) ** self.learning_rate_power
286
+ )
287
+ learning_rate_e_step = 0
288
+ return learning_rate_m_step, learning_rate_e_step
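+
+ # Illustrative values (assuming nb_phase1_iterations=100, learning_rate_power=0.8,
+ # init_step_size_adaptation=0.5):
+ #   k = 10  (phase 1): alpha_k = 1.0, gamma_k = 0.5 / sqrt(10) ~ 0.16
+ #   k = 150 (phase 2): alpha_k = 1 / 51**0.8 ~ 0.043, gamma_k = 0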
289
+
290
+ def _stochastic_approximation(
291
+ self, previous: torch.Tensor, new: torch.Tensor
292
+ ) -> torch.Tensor:
293
+ """Perform stochastic approximation
294
+
295
+ Args:
296
+ previous (torch.Tensor): The current value of the tensor
297
+ new (torch.Tensor): The target value of the tensor
298
+
299
+ Returns:
300
+ torch.Tensor: (1 - learning_rate) * previous + learning_rate * new
301
+ """
302
+ assert (
303
+ previous.shape == new.shape
304
+ ), f"Wrong shape in stochastic approximation: {previous.shape}, {new.shape}"
305
+ stochastic_approx = (
306
+ (1 - self.learning_rate_m_step) * previous + self.learning_rate_m_step * new
307
+ ).to(device)
308
+ return stochastic_approx
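+
+ # Worked example (values assumed for illustration): with learning_rate_m_step = 0.1,
+ # previous = 4.0 and new = 10.0 yield 0.9 * 4.0 + 0.1 * 10.0 = 4.6.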
309
+
310
+ def _simulated_annealing(
311
+ self, current: torch.Tensor, target: torch.Tensor
312
+ ) -> torch.Tensor:
313
+ """Perform simulated annealing
314
+
315
+ This function limits how quickly a value can decrease: per iteration it may shrink by at most the factor stored in self.annealing_factor
316
+
317
+ Args:
318
+ current (torch.Tensor): Current value of the tensor
319
+ target (torch.Tensor): Target value of the tensor
320
+
321
+ Returns:
322
+ torch.Tensor: maximum(annealing_factor * current, target)
323
+ """
324
+ return torch.maximum(self.annealing_factor * current, target).to(device)
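+
+ # Worked example (values assumed for illustration): with annealing_factor = 0.95,
+ # current = 2.0 and target = 1.0 return max(0.95 * 2.0, 1.0) = 1.9, i.e. the variance
+ # can shrink by at most 5% per iteration during the exploration phase.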
325
+
326
+ def _clamp_eigen_values(self, omega: torch.Tensor, min_eigenvalue: float = 1e-6):
327
+ """
328
+ Project a symmetric matrix onto the cone of positive definite matrices by clamping its eigenvalues to at least `min_eigenvalue`.
329
+ """
330
+ # 1. Ensure symmetry (sometimes float error breaks symmetry slightly)
331
+ omega = (0.5 * (omega + omega.T)).to(device)
332
+
333
+ # 2. Eigen Decomposition
334
+ L, V = torch.linalg.eigh(omega)
335
+
336
+ # 3. Clamp eigenvalues
337
+ L_clamped = torch.clamp(L, min=min_eigenvalue)
338
+
339
+ # 4. Reconstruct
340
+ matrix_spd = torch.matmul(V, torch.matmul(torch.diag(L_clamped), V.T))
341
+
342
+ return matrix_spd.to(device)
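+
+ # Worked example (values assumed for illustration): the singular matrix [[1, 1], [1, 1]]
+ # has eigenvalues {0, 2}; clamping raises the zero eigenvalue to min_eigenvalue (1e-6),
+ # and the reconstructed matrix is symmetric positive definite.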
343
+
344
+ def _init_history(
345
+ self,
346
+ beta: torch.Tensor,
347
+ omega: torch.Tensor,
348
+ log_mi: torch.Tensor,
349
+ res_var: torch.Tensor,
350
+ likelihood: torch.Tensor,
351
+ flagged_patients: list,
352
+ ) -> None:
353
+ # Initialize the history
354
+ self.history = {}
355
+ # Add the pdus (mean, variance)
356
+ for i, pdu in enumerate(self.model.PDU_names):
357
+ beta_index = self.model.population_betas_names.index(pdu)
358
+ self.history.update(
359
+ {
360
+ pdu: {
361
+ "mu": [beta[beta_index].cpu()],
362
+ "sigma_sq": [omega[i, i].cpu()],
363
+ }
364
+ }
365
+ )
366
+ # Add the covariates
367
+ for i, cov in enumerate(self.model.covariate_coeffs_names):
368
+ beta_index = self.model.population_betas_names.index(cov)
369
+ self.history.update({cov: [beta[beta_index].cpu()]})
370
+ # Add Omega
371
+ self.history.update({"omega": [omega.cpu()]})
372
+ # Add the model intrinsic params
373
+ for i, mi in enumerate(self.model.MI_names):
374
+ self.history.update({mi: [log_mi[i].cpu()]})
375
+ # Add the residual variance
376
+ for i, output in enumerate(self.model.outputs_names):
377
+ self.history.update({output: [res_var[i].cpu()]})
378
+ self.history.update({"complete_likelihood": [likelihood.cpu()]})
379
+ self.history.update({"flagged_patients": [flagged_patients]})
380
+
381
+ def _append_history(
382
+ self,
383
+ beta: torch.Tensor,
384
+ omega: torch.Tensor,
385
+ log_mi: torch.Tensor,
386
+ res_var: torch.Tensor,
387
+ complete_likelihood: torch.Tensor,
388
+ flagged_patients: list,
389
+ ) -> None:
390
+ # Update the history
391
+ for i, pdu in enumerate(self.model.PDU_names):
392
+ beta_index = self.model.population_betas_names.index(pdu)
393
+ self.history[pdu]["mu"].append(beta[beta_index].cpu())
394
+ self.history[pdu]["sigma_sq"].append(omega[i, i].cpu())
395
+
396
+ for i, cov in enumerate(self.model.covariate_coeffs_names):
397
+ beta_index = self.model.population_betas_names.index(cov)
398
+ self.history[cov].append(beta[beta_index].cpu())
399
+
400
+ self.history["omega"].append(omega.cpu())
401
+
402
+ for i, mi in enumerate(self.model.MI_names):
403
+ self.history[mi].append(log_mi[i].cpu())
404
+
405
+ for i, output in enumerate(self.model.outputs_names):
406
+ self.history[output].append(res_var[i].cpu())
407
+ self.history["complete_likelihood"].append(complete_likelihood.cpu())
408
+ self.history["flagged_patients"].append(flagged_patients)
409
+
410
+ def one_iteration(self, k: int) -> bool:
411
+ """Perform one iteration of SAEM
412
+
413
+ Args:
414
+ k (int): the iteration number
415
+ """
416
+
417
+ if self.verbose:
418
+ print(f"Running iteration {k}")
419
+ # If first iteration, consider burn in
420
+ if k == 0:
421
+ current_iter_burn_in = self.mcmc_first_burn_in
422
+ else:
423
+ current_iter_burn_in = 0
424
+
425
+ self.learning_rate_m_step, self.step_size_adaptation = (
426
+ self._compute_learning_rates(k)
427
+ )
428
+
429
+ # --- E-step: perform MCMC kernel transitions
430
+ if self.verbose:
431
+ print(" MCMC sampling")
432
+ print(
433
+ f" Current MCMC parameters: step-size={self.step_size:.2f}, adaptation rate={self.step_size_adaptation:.2f}"
434
+ )
435
+
436
+ flagged_patients_iter = []
437
+ for _ in range(current_iter_burn_in + self.mcmc_nb_transitions):
438
+ (
439
+ self.current_etas,
440
+ self.current_log_prob,
441
+ self.current_complete_likelihood,
442
+ self.current_pred,
443
+ self.current_thetas,
444
+ self.current_log_pdu,
445
+ self.step_size,
446
+ flagged_patients,
447
+ ) = self.model.mh_step(
448
+ current_etas=self.current_etas,
449
+ current_log_prob=self.current_log_prob,
450
+ current_pred=self.current_pred,
451
+ current_thetas=self.current_thetas,
452
+ current_pdu=self.current_log_pdu,
453
+ step_size=self.step_size,
454
+ learning_rate=self.step_size_adaptation,
455
+ verbose=self.verbose,
456
+ )
457
+ flagged_patients_iter += flagged_patients
458
+
459
+ # Update the model's eta and thetas
460
+ self.model.update_eta_samples(self.current_etas)
461
+ self.model.update_map_estimates(self.current_thetas)
462
+
463
+ # --- M-Step: Update Population Means, Omega and Residual variance ---
464
+
465
+ # 1. Update residual error variances
466
+ sum_sq_res = self.model.sum_sq_residuals(self.current_pred)
467
+ target_res_var: torch.Tensor = (
468
+ sum_sq_res / self.model.n_tot_observations_per_output
469
+ )
470
+ current_res_var: torch.Tensor = self.model.residual_var
471
+ if k < self.nb_phase1_iterations:
472
+ target_res_var = self._simulated_annealing(current_res_var, target_res_var)
473
+
474
+ new_residual_error_var = self._stochastic_approximation(
475
+ current_res_var, target_res_var
476
+ )
477
+
478
+ self.model.update_res_var(new_residual_error_var)
479
+
480
+ # 2. Update sufficient statistics with stochastic approximation
481
+ (
482
+ self.sufficient_stat_cross_product,
483
+ self.sufficient_stat_outer_product,
484
+ new_beta,
485
+ new_omega,
486
+ ) = self.m_step_update(
487
+ self.current_log_pdu,
488
+ self.sufficient_stat_cross_product,
489
+ self.sufficient_stat_outer_product,
490
+ )
491
+
492
+ # Update beta
493
+ self.model.update_betas(new_beta)
494
+
495
+ # Update omega
496
+ if k < self.nb_phase1_iterations:
497
+ # Simulated annealing during phase 1
498
+ new_omega_diag = torch.diag(new_omega).to(device)
499
+ current_omega_diag = torch.diag(self.model.omega_pop).to(device)
500
+ annealed_omega_diag = self._simulated_annealing(
501
+ current_omega_diag, new_omega_diag
502
+ )
503
+ new_omega = torch.diag(annealed_omega_diag).to(device)
504
+ self.model.update_omega(new_omega)
505
+
506
+ # 3. Update fixed effects MIs
507
+ if self.model.nb_MI > 0:
508
+ # This step is notoriously under-optimized
509
+ self.current_full_res_var_for_MI = self.model.residual_var.index_select(
510
+ 0, self.model.full_output_indices
511
+ )
512
+ target_log_MI_np = minimize(
513
+ fun=self.MI_objective_function,
514
+ x0=self.model.log_MI.cpu().squeeze().numpy(),
515
+ method="L-BFGS-B",
516
+ options={"maxfun": self.optim_max_fun},
517
+ ).x
518
+ target_log_MI = torch.from_numpy(target_log_MI_np).to(device)
519
+ new_log_MI = self._stochastic_approximation(
520
+ self.model.log_MI, target_log_MI
521
+ )
522
+
523
+ self.model.update_log_mi(new_log_MI)
524
+
525
+ if self.verbose:
526
+ print(
527
+ f" Updated MIs: {', '.join([f'{torch.exp(logMI).item():.4f}' for logMI in self.model.log_MI.detach().cpu()])}"
528
+ )
529
+ print(
530
+ f" Updated Betas: {', '.join([f'{beta:.4f}' for beta in self.model.population_betas.detach().cpu().numpy().flatten()])}"
531
+ )
532
+ print(
533
+ f" Updated Omega (diag): {', '.join([f'{val.item():.4f}' for val in torch.diag(self.model.omega_pop.detach().cpu())])}"
534
+ )
535
+ print(
536
+ f" Updated Residual Var: {', '.join([f'{res_var:.4f}' for res_var in self.model.residual_var.detach().cpu().numpy().flatten()])}"
537
+ )
538
+
539
+ # Convergence check
540
+ new_params: dict[str, torch.Tensor] = {
541
+ "log_MI": self.model.log_MI,
542
+ "population_betas": self.model.population_betas,
543
+ "population_omega": self.model.omega_pop,
544
+ "residual_error_var": self.model.residual_var,
545
+ }
546
+ is_converged = self._check_convergence(new_params)
547
+
548
+ # store history
549
+ self._append_history(
550
+ self.model.population_betas,
551
+ self.model.omega_pop,
552
+ self.model.log_MI,
553
+ self.model.residual_var,
554
+ self.current_complete_likelihood,
555
+ flagged_patients_iter,
556
+ )
557
+
558
+ # update prev_params for the next iteration's convergence check
559
+ self.prev_params = new_params
560
+
561
+ if self.verbose:
562
+ print("Iter done")
563
+ return is_converged
564
+
565
+ def MI_objective_function(self, log_MI):
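+ """Negative total observation log-likelihood for a candidate log_MI vector.
+
+ Used as the objective of scipy.optimize.minimize when calibrating the
+ model-intrinsic (MI) parameters during the M-step.
+
+ Args:
+ log_MI (array-like): Log-scaled MI values proposed by the optimizer.
+
+ Returns:
+ float: Minus the sum of the observation log-likelihoods over all patients.
+ """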
566
+ log_MI_expanded = (
567
+ torch.as_tensor(log_MI, device=device)
568
+ .unsqueeze(0)
569
+ .repeat((self.model.nb_patients, 1))
570
+ )
571
+ if hasattr(self.model, "patients_pdk"):
572
+ pdk_full = self.model.patients_pdk_full
573
+ else:
574
+ pdk_full = torch.empty((self.model.nb_patients, 0), device=device)
575
+ # Assemble the patient parameters in the right order: PDK, PDU, MI
576
+ new_thetas = torch.cat(
577
+ (
578
+ pdk_full,
579
+ torch.exp(
580
+ torch.cat(
581
+ (
582
+ self.current_log_pdu,
583
+ log_MI_expanded,
584
+ ),
585
+ dim=1,
586
+ ),
587
+ ),
588
+ ),
589
+ dim=1,
590
+ )
591
+ predictions, _ = self.model.predict_outputs_from_theta(new_thetas)
592
+ total_log_lik = (
593
+ self.model.log_likelihood_observation(
594
+ predictions,
595
+ )
596
+ .cpu()
597
+ .sum()
598
+ .item()
599
+ )
600
+
601
+ return -total_log_lik
602
+
603
+ def run(
604
+ self,
605
+ ) -> None:
606
+ """
607
+ Run the SAEM estimation: phase 1 (exploration) followed by phase 2 (smoothing), calling one_iteration at each step.
608
+ Estimates are written to the model (population_betas, omega_pop, log_MI, residual_var) and recorded in self.history; the method itself returns None.
609
+ The current state of the estimation is stored so that iterations can be resumed later with the continue_iterating method.
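+
+ Example (illustrative sketch; assumes `saem` is an initialized PySaem instance):
+
+     saem.run()
+     saem.continue_iterating(nb_add_iters_ph2=50)  # refine further if not yet converged
+     saem.plot_convergence_history()               # redraw the full convergence history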
610
+ """
611
+ if self.verbose:
612
+ print("Starting SAEM Estimation...")
613
+ print(
614
+ f"Initial Population Betas: {', '.join([f'{beta.item():.2f}' for beta in self.model.population_betas.cpu()])}"
615
+ )
616
+ print(
617
+ f"Initial Population MIs: {', '.join([f'{torch.exp(logMI).item():.2f}' for logMI in self.model.log_MI.cpu()])}"
618
+ )
619
+ print(f"Initial Omega:\n{self.model.omega_pop.cpu()}")
620
+ print(f"Initial Residual Variance: {self.model.residual_var.cpu()}")
621
+
622
+ print("Phase 1 (exploration):")
623
+ (
624
+ self.convergence_plot_handle,
625
+ self.convergence_plot_fig,
626
+ self.convergence_plot_axes,
627
+ ) = self._build_convergence_plot(
628
+ indiv_figsize=self.plot_indiv_figsize, n_cols=self.plot_columns
629
+ )
630
+ for k in tqdm(range(1, self.nb_phase1_iterations)):
631
+ # Run iteration, do not check for convergence in the exploration phase
632
+ # Iteration 0 is in fact already done (initialization)
633
+ _ = self.one_iteration(k)
634
+ self.current_iteration = k
635
+ if self.live_plot and (k % self.plot_frames == 0):
636
+ self._update_convergence_plot()
637
+
638
+ if self.nb_phase2_iterations > 0:
639
+ self.current_phase = 2
640
+ print("Phase 2 (smoothing):")
641
+ for k in tqdm(
642
+ range(
643
+ self.nb_phase1_iterations,
644
+ self.nb_phase1_iterations + self.nb_phase2_iterations,
645
+ )
646
+ ):
647
+ # Run iteration
648
+ is_converged = self.one_iteration(k)
649
+ self.current_iteration = k
650
+
651
+ if self.live_plot and (k % self.plot_frames == 0):
652
+ self._update_convergence_plot()
653
+ # Check for convergence, and stop if criterion matched
654
+ if is_converged:
655
+ self.consecutive_converged_iters += 1
656
+ if self.verbose:
657
+ print(
658
+ f"Convergence met. Consecutive iterations: {self.consecutive_converged_iters}/{self.patience}"
659
+ )
660
+ if self.consecutive_converged_iters >= self.patience:
661
+ print(
662
+ f"\nConvergence reached after {k + 1} iterations. Stopping early."
663
+ )
664
+ self._update_convergence_plot()
665
+ break
666
+ else:
667
+ self.consecutive_converged_iters = 0
668
+ self._update_convergence_plot()
669
+ plt.close(self.convergence_plot_fig)
670
+ return None
671
+
672
+ def continue_iterating(self, nb_add_iters_ph1=0, nb_add_iters_ph2=0) -> None:
673
+ """
674
+ Use this method after run() has completed to perform additional phase 1 and/or phase 2 iterations.
675
+ """
676
+ if self.current_phase == 2:
677
+ if nb_add_iters_ph1 > 0:
678
+ print("Smoothing phase has started, cannot add phase 1 iterations.")
679
+ nb_add_iters_ph1 = 0
680
+ if self.current_phase == 1:
681
+ if nb_add_iters_ph1 > 0:
682
+ print("Continuing phase 1 (exploration):")
683
+ # Continue from where phase 1 stopped and extend the exploration phase so that the
+ # phase-1 learning rates and simulated annealing still apply to the added iterations.
+ start_iter = self.nb_phase1_iterations
+ self.nb_phase1_iterations += nb_add_iters_ph1
+ for k in tqdm(range(start_iter, self.nb_phase1_iterations)):
+ # Run iteration, do not check for convergence in the exploration phase
+ _ = self.one_iteration(k)
+ self.current_iteration = k
686
+
687
+ print("Switching to Phase 2 (smoothing)")
688
+ self.current_phase = 2
689
+
690
+ if nb_add_iters_ph2 > 0:
691
+ for k in tqdm(
692
+ range(
693
+ self.nb_phase1_iterations + self.nb_phase2_iterations,
+ self.nb_phase1_iterations
+ + self.nb_phase2_iterations
+ + nb_add_iters_ph2,
700
+ )
701
+ ):
702
+ # Run iteration
703
+ is_converged = self.one_iteration(k)
+ self.current_iteration = k
704
+ # Check for convergence, and stop if criterion matched
705
+ if is_converged:
706
+ self.consecutive_converged_iters += 1
707
+ if self.verbose:
708
+ print(
709
+ f"Convergence met. Consecutive iterations: {self.consecutive_converged_iters}/{self.patience}"
710
+ )
711
+ if self.consecutive_converged_iters >= self.patience:
712
+ print(
713
+ f"\nConvergence reached after {k + 1} iterations. Stopping early."
714
+ )
715
+ break
716
+ else:
717
+ self.consecutive_converged_iters = 0
718
+ return None
719
+
720
+ def _build_convergence_plot(
721
+ self,
722
+ indiv_figsize: tuple[float, float] = (2.0, 1.2),
723
+ n_cols: int = 3,
724
+ ):
725
+ """
726
+ Build the convergence figure showing the evolution of the estimated parameters (MI parameters, PDU means and variances, covariate coefficients, residual error variances), the complete likelihood, and the number of out-of-bounds patients across iterations.
727
+ """
728
+ history = self.history
729
+ nb_MI: int = self.model.nb_MI
730
+ nb_betas: int = self.model.nb_betas
731
+ nb_omega_diag_params: int = self.model.nb_PDU
732
+ nb_var_res_params: int = self.model.nb_outputs
733
+ nb_plots = nb_MI + nb_betas + nb_omega_diag_params + nb_var_res_params + 2
734
+ nb_cols = n_cols
735
+ nb_rows = int(np.ceil(nb_plots / nb_cols))
736
+ maxiter = self.nb_phase1_iterations + self.nb_phase2_iterations
737
+ fig, axes = plt.subplots(
738
+ nrows=nb_rows,
739
+ ncols=nb_cols,
740
+ figsize=(
741
+ nb_cols * indiv_figsize[0],
742
+ nb_rows * indiv_figsize[1],
743
+ ),
744
+ squeeze=False,
745
+ sharex="all",
746
+ )
747
+
748
+ self.traces = {}
749
+ plot_idx: int = 0
750
+ # Plot the MI parameters
751
+ for mi_name in self.model.MI_names:
752
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
753
+ ax = axes[row, col]
754
+ ax.set_xlim(0, maxiter)
755
+ MI_history = [h.item() for h in history[mi_name]]
756
+ (tr,) = ax.plot(
757
+ MI_history,
758
+ )
759
+ if hasattr(self, "true_log_MI"):
760
+ if self.true_log_MI is not None:
761
+ ax.axhline(
762
+ y=self.true_log_MI[mi_name],
763
+ linestyle="--",
764
+ )
765
+ if hasattr(self, "training_ranges"):
766
+ if self.training_ranges is not None:
767
+ ax.fill_between(
768
+ [0, maxiter],
769
+ self.training_ranges[mi_name]["low"],
770
+ self.training_ranges[mi_name]["high"],
771
+ alpha=0.25,
772
+ )
773
+ ax.set_title("Model intrinsic {MI_name}")
774
+ ax.grid(True)
775
+ self.traces.update({mi_name: tr})
776
+ plot_idx += 1
777
+ # Plot the PDUs means
778
+ for pdu in self.model.PDU_names:
779
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
780
+ ax = axes[row, col]
781
+ ax.set_xlim(0, maxiter)
782
+ beta_history = [h.item() for h in history[pdu]["mu"]]
783
+ (tr,) = ax.plot(
784
+ beta_history,
785
+ )
786
+ if hasattr(self, "true_log_PDUs"):
787
+ if self.true_log_PDUs is not None:
788
+ ax.axhline(
789
+ y=self.true_log_PDUs[pdu]["mean"],
790
+ linestyle="--",
791
+ )
792
+ if hasattr(self, "training_ranges"):
793
+ if self.training_ranges is not None:
794
+ ax.fill_between(
795
+ [0, maxiter],
796
+ self.training_ranges[pdu]["low"],
797
+ self.training_ranges[pdu]["high"],
798
+ alpha=0.25,
799
+ )
800
+ ax.set_title(rf"{pdu}: $\mu$ (log)")
801
+ ax.set_xlabel("")
802
+ ax.grid(True)
803
+ self.traces.update({pdu: {"mu": tr}})
804
+ plot_idx += 1
805
+ # Plot the PDUs sigma
806
+ for pdu in self.model.PDU_names:
807
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
808
+ ax = axes[row, col]
809
+ ax.set_xlim(0, maxiter)
810
+ beta_history = [h.item() for h in history[pdu]["sigma_sq"]]
811
+ (tr,) = ax.plot(
812
+ beta_history,
813
+ )
814
+ if hasattr(self, "true_log_PDUs"):
815
+ if self.true_log_PDUs is not None:
816
+ ax.axhline(
817
+ y=self.true_log_PDUs[pdu]["sd"],
818
+ linestyle="--",
819
+ )
820
+ ax.set_title(rf"{pdu}: $\sigma^2$")
821
+ ax.set_xlabel("")
822
+ ax.grid(True)
823
+ self.traces[pdu].update({"sigma_sq": tr})
824
+ plot_idx += 1
825
+ # Plot the covariate coefficients
826
+ for beta_name in self.model.covariate_coeffs_names:
827
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
828
+ ax = axes[row, col]
829
+ ax.set_xlim(0, maxiter)
830
+ beta_history = [h.item() for h in history[beta_name]]
831
+ (tr,) = ax.plot(
832
+ beta_history,
833
+ )
834
+ if hasattr(self, "true_cov"):
835
+ if self.true_cov is not None:
836
+ ax.axhline(
837
+ y=self.true_cov[beta_name],
838
+ linestyle="--",
839
+ )
840
+ ax.set_title(rf"{beta_name}")
841
+ ax.set_xlabel("")
842
+ ax.grid(True)
843
+ self.traces.update({beta_name: tr})
844
+ plot_idx += 1
845
+ # Plot the residual variance
846
+ for res_name in self.model.outputs_names:
847
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
848
+ ax = axes[row, col]
849
+ ax.set_xlim(0, maxiter)
850
+ var_res_history = [h.item() for h in history[res_name]]
851
+ (tr,) = ax.plot(
852
+ var_res_history,
853
+ )
854
+ if hasattr(self, "true_res_var"):
855
+ if self.true_res_var is not None:
856
+ ax.axhline(
857
+ y=self.true_res_var[res_name],
858
+ linestyle="--",
859
+ )
860
+ ax.set_title(rf"{res_name}: $\sigma^2$")
861
+ ax.grid(True)
862
+ self.traces.update({res_name: tr})
863
+ plot_idx += 1
864
+ # Plot the convergence indicator (total log prob)
865
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
866
+ ax = axes[row, col]
867
+ ax.set_xlim(0, maxiter)
868
+ convergence_ind = [h.item() for h in history["complete_likelihood"]]
869
+ (tr,) = ax.plot(
870
+ convergence_ind,
871
+ )
872
+ ax.set_title(rf"Convergence indicator")
873
+ ax.grid(True)
874
+ self.traces.update({"convergence_ind": tr})
875
+ plot_idx += 1
876
+ # Plot the number of out of bounds patients
877
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
878
+ ax = axes[row, col]
879
+ ax.set_xlim(0, maxiter)
880
+ ax.set_ylim(0, self.model.nb_patients)
881
+ oob_patients = [len(h) for h in history["flagged_patients"]]
882
+ (tr,) = ax.plot(
883
+ oob_patients,
884
+ )
885
+ ax.set_title(rf"Out-of-bounds patients")
886
+ ax.grid(True)
887
+ self.traces.update({"oob_patients": tr})
888
+ plot_idx += 1
889
+
890
+ # Turn off extra subplots
891
+ while plot_idx < nb_rows * nb_cols:
892
+ row, col = plot_idx // nb_cols, plot_idx % nb_cols
893
+ ax = axes[row, col]
894
+ ax.set_visible(False)
895
+ plot_idx += 1
896
+ if not smoke_test:
897
+ plt.tight_layout()
898
+ handle = display(fig, display_id=True)
899
+ else:
900
+ handle = None
901
+ return (handle, fig, axes)
902
+
903
+ def _update_convergence_plot(self):
904
+ history = self.history
905
+ new_xaxis = np.arange(self.current_iteration + 1)
906
+ # Plot the MI parameters
907
+ for mi_name in self.model.MI_names:
908
+ MI_history = [h.item() for h in history[mi_name]]
909
+ self.traces[mi_name].set_data(new_xaxis, MI_history)
910
+ # Plot the PDUs means
911
+ for pdu in self.model.PDU_names:
912
+ beta_history = [h.item() for h in history[pdu]["mu"]]
913
+ self.traces[pdu]["mu"].set_data(new_xaxis, beta_history)
914
+ # Plot the PDUs sigma
915
+ for pdu in self.model.PDU_names:
916
+ beta_history = [h.item() for h in history[pdu]["sigma_sq"]]
917
+ self.traces[pdu]["sigma_sq"].set_data(new_xaxis, beta_history)
918
+ # Plot the covariate coefficients
919
+ for beta_name in self.model.covariate_coeffs_names:
920
+ beta_history = [h.item() for h in history[beta_name]]
921
+ self.traces[beta_name].set_data(new_xaxis, beta_history)
922
+ # Plot the residual variance
923
+ for res_name in self.model.outputs_names:
924
+ var_res_history = [h.item() for h in history[res_name]]
925
+ self.traces[res_name].set_data(new_xaxis, var_res_history)
926
+ conv_ind = [h.item() for h in history["complete_likelihood"]]
927
+ self.traces["convergence_ind"].set_data(new_xaxis, conv_ind)
928
+ oob_patients = [len(h) for h in history["flagged_patients"]]
929
+ self.traces["oob_patients"].set_data(new_xaxis, oob_patients)
930
+ if not smoke_test:
931
+ for ax in self.convergence_plot_axes.flatten():
932
+ ax.relim()  # recompute data limits from the updated traces first
933
+ ax.autoscale_view(scalex=False, scaley=True)
934
+ if self.convergence_plot_handle is not None:
935
+ self.convergence_plot_handle.update(self.convergence_plot_fig)
936
+
937
+ def plot_convergence_history(
938
+ self,
939
+ indiv_figsize: tuple[float, float] = (2.0, 1.2),
940
+ n_cols: int = 3,
941
+ ):
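+ """Rebuild and display the convergence plot from the stored optimization history."""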
942
+ handle, fig, axes = self._build_convergence_plot(
943
+ indiv_figsize=indiv_figsize, n_cols=n_cols
944
+ )
945
+ plt.close(fig)