vpop-calibration 2.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vpop_calibration/__init__.py +22 -0
- vpop_calibration/data_generation.py +186 -0
- vpop_calibration/diagnostics.py +162 -0
- vpop_calibration/model/__init__.py +3 -0
- vpop_calibration/model/data.py +420 -0
- vpop_calibration/model/gp.py +517 -0
- vpop_calibration/model/plot.py +243 -0
- vpop_calibration/nlme.py +840 -0
- vpop_calibration/ode.py +203 -0
- vpop_calibration/saem.py +945 -0
- vpop_calibration/structural_model.py +200 -0
- vpop_calibration/test/__init__.py +11 -0
- vpop_calibration/test/test_data.py +21 -0
- vpop_calibration/test/test_gp_flavors.py +89 -0
- vpop_calibration/test/test_gp_saem.py +175 -0
- vpop_calibration/test/test_ode_saem.py +121 -0
- vpop_calibration/utils.py +9 -0
- vpop_calibration/vpop.py +50 -0
- vpop_calibration-2.2.8.dist-info/METADATA +78 -0
- vpop_calibration-2.2.8.dist-info/RECORD +22 -0
- vpop_calibration-2.2.8.dist-info/WHEEL +4 -0
- vpop_calibration-2.2.8.dist-info/licenses/LICENSE +21 -0
vpop_calibration/saem.py
ADDED
|
@@ -0,0 +1,945 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from scipy.optimize import minimize
|
|
4
|
+
from tqdm.notebook import tqdm
|
|
5
|
+
from typing import Union, Optional
|
|
6
|
+
from pandas import DataFrame
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import numpy as np
|
|
9
|
+
from IPython.display import display
|
|
10
|
+
|
|
11
|
+
from .utils import smoke_test, device
|
|
12
|
+
from .nlme import NlmeModel
|
|
13
|
+
from .structural_model import StructuralGp
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Main SAEM Algorithm Class
|
|
17
|
+
class PySaem:
|
|
18
|
+
def __init__(
    self,
    model: NlmeModel,
    observations_df: DataFrame,
    # MCMC parameters for the E-step
    mcmc_first_burn_in: int = 5,
    mcmc_nb_transitions: int = 1,
    nb_phase1_iterations: int = 100,
    nb_phase2_iterations: Union[int, None] = None,
    convergence_threshold: float = 1e-4,
    patience: int = 5,
    learning_rate_power: float = 0.8,
    annealing_factor: float = 0.95,
    init_step_size: float = 0.5,  # stick to the 0.1 - 1 range
    verbose: bool = False,
    optim_max_fun: int = 50,
    live_plot: bool = True,
    plot_frames: int = 20,
    plot_columns: int = 3,
    plot_indiv_figsize: tuple[float, float] = (3.0, 1.2),
    true_log_MI: Optional[dict[str, float]] = None,
    true_log_PDU: Optional[dict[str, dict[str, float | bool]]] = None,
    true_res_var: Optional[list[float]] = None,
    true_covariates: Optional[dict[str, dict[str, dict[str, str | float]]]] = None,
):
    """Instantiate an SAEM optimizer for an NLME model

    The observations are attached to the model, the initial individual
    predictions are computed from the model's current random effects, and the
    sufficient statistics and history are initialized (this initial state is
    iteration 0).

    Args:
        model (NlmeModel): The model to be optimized
        observations_df (DataFrame): The data set containing observations
        mcmc_first_burn_in (int, optional): Number of discarded samples in the first iteration. Defaults to 5.
        mcmc_nb_transitions (int, optional): Number of kernel transitions computed at each iteration. Defaults to 1.
        nb_phase1_iterations (int, optional): Number of iterations in the exploration phase. Defaults to 100.
        nb_phase2_iterations (Union[int, None], optional): Number of iterations in the convergence phase. Defaults to None, implying nb_phase_2 = nb_phase_1.
        convergence_threshold (float, optional): Estimated parameter relative change threshold considered for convergence. Defaults to 1e-4.
        patience (int, optional): Number of iterations of consecutive low relative change considered for early stopping of the algorithm. Defaults to 5.
        learning_rate_power (float, optional): Exponential decay exponent for the M-step learning rate (stochastic approximation). Defaults to 0.8.
        annealing_factor (float, optional): Exploration phase annealing factor for residual and parameter variance. Defaults to 0.95.
        init_step_size (float, optional): Initial MCMC step size scaling factor (divided by sqrt(nb_PDU)). Defaults to 0.5.
        verbose (bool): Print various info during iterations. Defaults to False.
        optim_max_fun (int): Maximum number of function calls in the scipy.optimize (used for model intrinsic parameters calibration). Defaults to 50.
        live_plot (bool): Print and update a plot of parameters during iterations. Defaults to True.
        plot_frames (int): Frequency at which the live plot should be updated (number of iterations). The lower the slower. Defaults to 20.
        plot_columns (int): Number of columns to display the convergence plot. Defaults to 3.
        plot_indiv_figsize (tuple[float,float]): individual figure size in the convergence plot (width, height). Defaults to (3.0, 1.2).
        true_log_MI (dict[str, float], optional): Reference log-MI values, stored for plotting. Defaults to None.
        true_log_PDU (dict, optional): Reference PDU description, stored as ``self.true_log_PDUs``. Defaults to None.
        true_res_var (list[float], optional): Reference residual variances, one per model output (in
            ``model.outputs_names`` order). ``self.true_res_var`` is only set when provided. Defaults to None.
        true_covariates (dict, optional): Reference covariate description; every inner entry must carry
            "coef" and "value" keys. ``self.true_cov`` is only set when provided. Defaults to None.
    """

    self.model: NlmeModel = model
    self.model.add_observations(observations_df)
    # MCMC sampling in the E-step parameters
    self.mcmc_first_burn_in: int = mcmc_first_burn_in
    self.mcmc_nb_transitions: int = mcmc_nb_transitions
    # SAEM iteration parameters
    # phase 1 = exploratory: M-step learning rate = 1 (no averaging) and simulated annealing on
    # phase 2 = smoothing: learning rate 1 / (phase-2 iteration index) ** learning_rate_power
    if smoke_test:
        self.nb_phase1_iterations = 1
        self.nb_phase2_iterations = 2
    else:
        self.nb_phase1_iterations: int = nb_phase1_iterations
        self.nb_phase2_iterations: int = (
            nb_phase2_iterations
            if nb_phase2_iterations is not None
            else nb_phase1_iterations
        )
    self.current_phase = 1

    # convergence parameters
    self.convergence_threshold: float = convergence_threshold
    self.patience: int = patience
    self.consecutive_converged_iters: int = 0

    # Numerical parameters that depend on the iterations phase
    # MH proposal step size, scaled down with the number of PDUs
    self.step_size: float = init_step_size / np.sqrt(self.model.nb_PDU)
    # Base rate and decay exponent of the step-size adaptation in E-step sampling
    self.init_step_size_adaptation: float = 0.5
    self.step_size_learning_rate_power: float = 0.5

    # Parameters of the stochastic approximation in the M-step
    self.learning_rate_power: float = learning_rate_power
    self.annealing_factor: float = annealing_factor

    # Initialize the M-step learning rate and step size adaptation rate (iteration 0)
    self.learning_rate_m_step, self.step_size_adaptation = (
        self._compute_learning_rates(0)
    )

    self.verbose = verbose
    if smoke_test:
        self.optim_max_fun = 1
    else:
        self.optim_max_fun = optim_max_fun
    self.live_plot = live_plot
    self.plot_frames = plot_frames
    self.plot_columns = plot_columns
    self.plot_indiv_figsize = plot_indiv_figsize

    # Start from the model's current random-effect samples
    self.current_etas = self.model.current_eta_samples

    # Initialize current estimation of patient parameters from these random effects
    (
        self.current_log_prob,
        self.current_thetas,
        self.current_log_pdu,
        self.current_pred,
        flagged_patients,
    ) = self.model.log_posterior_etas(self.current_etas)
    self.current_complete_likelihood = torch.exp(
        self.current_log_prob.sum(dim=0)
    ).to(device)

    # Initialize the optimizer history
    self._init_history(
        self.model.population_betas,
        self.model.omega_pop,
        self.model.log_MI,
        self.model.residual_var,
        self.current_complete_likelihood,
        flagged_patients,
    )
    self.current_iteration: int = 0

    # Initialize the values for convergence checks
    self.prev_params: dict[str, torch.Tensor] = {
        "log_MI": self.model.log_MI,
        "population_betas": self.model.population_betas,
        "population_omega": self.model.omega_pop,
        "residual_error_var": self.model.residual_var,
    }

    # pre-compute full design matrix once (patients stacked along dim 0)
    self.X = torch.stack(
        [self.model.design_matrices[ind] for ind in self.model.patients],
        dim=0,
    ).to(device)
    # Precompute the gram matrix sum_i X_i^T X_i (constant across iterations)
    self.sufficient_stat_gram_matrix = (
        torch.matmul(self.X.transpose(1, 2), self.X).sum(dim=0).to(device)
    )

    # Initialize sufficient statistics
    self.sufficient_stat_cross_product = (
        (self.X.transpose(1, 2) @ self.current_log_pdu.unsqueeze(-1))
        .sum(dim=0)
        .to(device)
    )
    self.sufficient_stat_outer_product = torch.matmul(
        self.current_log_pdu.transpose(0, 1), self.current_log_pdu
    ).to(device)

    self.true_log_MI = true_log_MI
    self.true_log_PDUs = true_log_PDU
    if true_covariates is not None:
        self.true_cov = {
            str(cov["coef"]): float(cov["value"])
            for item in true_covariates.values()
            for cov in item.values()
        }
    if true_res_var is not None:
        self.true_res_var = {
            self.model.outputs_names[k]: val for k, val in enumerate(true_res_var)
        }

    # For GP surrogates, record the (log-scaled) range of the training data per
    # parameter so that plots can show where the surrogate was trained.
    if isinstance(self.model.structural_model, StructuralGp):
        self.training_ranges = {}
        training_samples = np.log(
            self.model.structural_model.gp_model.data.full_df_raw[
                self.model.PDU_names + self.model.MI_names
            ]
        )
        train_min = training_samples.min(axis=0)
        train_max = training_samples.max(axis=0)
        for param in self.model.PDU_names + self.model.MI_names:
            self.training_ranges.update(
                {param: {"low": train_min[param], "high": train_max[param]}}
            )
|
|
196
|
+
|
|
197
|
+
def m_step_update(
    self,
    log_pdu: torch.Tensor,
    s_cross_product: torch.Tensor,
    s_outer_product: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the M-step: refresh both sufficient statistics, then re-estimate beta and Omega.

    Args:
        log_pdu (torch.Tensor): Current estimation of the log-scaled parameters; its
            first dimension must equal the number of stacked design matrices in ``self.X``.
        s_cross_product (torch.Tensor): Running sufficient statistic 1 (sum of X_i^T log_pdu_i).
        s_outer_product (torch.Tensor): Running sufficient statistic 2 (log_pdu^T log_pdu).

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            - updated cross-product statistic
            - updated outer-product statistic
            - new beta (trailing singleton dimension removed)
            - new omega, projected onto positive-definite matrices
    """
    assert log_pdu.shape[0] == self.X.shape[0]
    design_t = self.X.transpose(1, 2)

    # Statistic 1: sum over patients of X_i^T phi_i, blended with its running value
    batch_cross = (design_t @ log_pdu.unsqueeze(-1)).sum(dim=0).to(device)
    updated_cross = self._stochastic_approximation(s_cross_product, batch_cross)

    # Statistic 2: phi^T phi, blended with its running value
    batch_outer = (log_pdu.transpose(0, 1) @ log_pdu).to(device)
    updated_outer = self._stochastic_approximation(s_outer_product, batch_outer)

    # Normal equations: (sum_i X_i^T X_i) beta = statistic 1
    beta_hat = torch.linalg.solve(
        self.sufficient_stat_gram_matrix, updated_cross
    ).to(device)

    # Population prediction of the log-parameters, then the Omega proposal
    predicted_log_pdu = (self.X @ beta_hat.unsqueeze(0)).squeeze(-1).to(device)
    residual_scatter = (
        updated_outer - predicted_log_pdu.transpose(0, 1) @ predicted_log_pdu
    )
    omega_hat = (1 / self.model.nb_patients * residual_scatter).to(device)
    omega_hat = self._clamp_eigen_values(omega_hat)

    return updated_cross, updated_outer, beta_hat.squeeze(-1), omega_hat
|
|
252
|
+
|
|
253
|
+
def _check_convergence(self, new_params: dict[str, torch.Tensor]) -> bool:
|
|
254
|
+
"""Checks for convergence based on the relative change in parameters."""
|
|
255
|
+
all_converged = True
|
|
256
|
+
for name, current_val in new_params.items():
|
|
257
|
+
if current_val.shape[0] > 0:
|
|
258
|
+
prev_val = self.prev_params[name]
|
|
259
|
+
abs_diff = torch.abs(current_val - prev_val)
|
|
260
|
+
abs_sum = torch.abs(current_val) + torch.abs(prev_val) + 1e-9
|
|
261
|
+
relative_change = abs_diff / abs_sum
|
|
262
|
+
if torch.any(relative_change > self.convergence_threshold):
|
|
263
|
+
all_converged = False
|
|
264
|
+
break
|
|
265
|
+
return all_converged
|
|
266
|
+
|
|
267
|
+
def _compute_learning_rates(self, iteration: int) -> tuple[float, float]:
|
|
268
|
+
"""
|
|
269
|
+
Calculates the SAEM learning rate (alpha_k) and Metropolis Hastings step-size (gamma_k).
|
|
270
|
+
|
|
271
|
+
Phase 1:
|
|
272
|
+
alpha_k = 1 (exploration)
|
|
273
|
+
gamma_k = c_0 / k^(0.5) , c0 = init_step_size_adaptation / sqrt(n_PDU)
|
|
274
|
+
Phase 2:
|
|
275
|
+
alpha_k = 1 / (iteration - phase1_iterations + 1) ^ exponent (the iteration index in phase 2)
|
|
276
|
+
gamma_k = 0
|
|
277
|
+
"""
|
|
278
|
+
if iteration < self.nb_phase1_iterations:
|
|
279
|
+
learning_rate_m_step = 1.0
|
|
280
|
+
learning_rate_e_step = self.init_step_size_adaptation / (
|
|
281
|
+
np.maximum(1, iteration) ** 0.5
|
|
282
|
+
)
|
|
283
|
+
else:
|
|
284
|
+
learning_rate_m_step = 1.0 / (
|
|
285
|
+
(iteration - self.nb_phase1_iterations + 1) ** self.learning_rate_power
|
|
286
|
+
)
|
|
287
|
+
learning_rate_e_step = 0
|
|
288
|
+
return learning_rate_m_step, learning_rate_e_step
|
|
289
|
+
|
|
290
|
+
def _stochastic_approximation(
    self, previous: torch.Tensor, new: torch.Tensor
) -> torch.Tensor:
    """Blend a running estimate with a fresh one using the current M-step learning rate.

    Args:
        previous (torch.Tensor): The running value.
        new (torch.Tensor): The freshly computed value; must have the same shape.

    Returns:
        torch.Tensor: (1 - learning_rate) * previous + learning_rate * new, moved to ``device``.
    """
    assert (
        previous.shape == new.shape
    ), f"Wrong shape in stochastic approximation: {previous.shape}, {new.shape}"
    rate = self.learning_rate_m_step
    blended = (1 - rate) * previous + rate * new
    return blended.to(device)
|
|
309
|
+
|
|
310
|
+
def _simulated_annealing(
    self, current: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
    """Limit how fast a quantity may shrink during exploration.

    The result never drops below ``annealing_factor * current``, so each call can
    reduce a value by at most that factor.

    Args:
        current (torch.Tensor): Current value of the tensor.
        target (torch.Tensor): Proposed new value of the tensor.

    Returns:
        torch.Tensor: element-wise maximum(annealing_factor * current, target), on ``device``.
    """
    floor = self.annealing_factor * current
    return torch.maximum(floor, target).to(device)
|
|
325
|
+
|
|
326
|
+
def _clamp_eigen_values(self, omega: torch.Tensor, min_eigenvalue: float = 1e-6):
    """
    Project a matrix onto the cone of Positive Definite matrices.

    The matrix is symmetrised, eigen-decomposed, eigenvalues below
    ``min_eigenvalue`` are raised to it, and the matrix is rebuilt as
    V diag(lambda) V^T. Returns a tensor on ``device``.
    """
    # Float round-off can break symmetry slightly; average with the transpose
    symmetric = (0.5 * (omega + omega.T)).to(device)
    eigvals, eigvecs = torch.linalg.eigh(symmetric)
    floored = torch.clamp(eigvals, min=min_eigenvalue)
    rebuilt = eigvecs @ (torch.diag(floored) @ eigvecs.T)
    return rebuilt.to(device)
|
|
343
|
+
|
|
344
|
+
def _init_history(
|
|
345
|
+
self,
|
|
346
|
+
beta: torch.Tensor,
|
|
347
|
+
omega: torch.Tensor,
|
|
348
|
+
log_mi: torch.Tensor,
|
|
349
|
+
res_var: torch.Tensor,
|
|
350
|
+
likelihood: torch.Tensor,
|
|
351
|
+
flagged_patients: list,
|
|
352
|
+
) -> None:
|
|
353
|
+
# Initialize the history
|
|
354
|
+
self.history = {}
|
|
355
|
+
# Add the pdus (mean, variance)
|
|
356
|
+
for i, pdu in enumerate(self.model.PDU_names):
|
|
357
|
+
beta_index = self.model.population_betas_names.index(pdu)
|
|
358
|
+
self.history.update(
|
|
359
|
+
{
|
|
360
|
+
pdu: {
|
|
361
|
+
"mu": [beta[beta_index].cpu()],
|
|
362
|
+
"sigma_sq": [omega[i, i].cpu()],
|
|
363
|
+
}
|
|
364
|
+
}
|
|
365
|
+
)
|
|
366
|
+
# Add the covariates
|
|
367
|
+
for i, cov in enumerate(self.model.covariate_coeffs_names):
|
|
368
|
+
beta_index = self.model.population_betas_names.index(cov)
|
|
369
|
+
self.history.update({cov: [beta[beta_index].cpu()]})
|
|
370
|
+
# Add Omega
|
|
371
|
+
self.history.update({"omega": [omega.cpu()]})
|
|
372
|
+
# Add the model intrinsic params
|
|
373
|
+
for i, mi in enumerate(self.model.MI_names):
|
|
374
|
+
self.history.update({mi: [log_mi[i].cpu()]})
|
|
375
|
+
# Add the residual variance
|
|
376
|
+
for i, output in enumerate(self.model.outputs_names):
|
|
377
|
+
self.history.update({output: [res_var[i].cpu()]})
|
|
378
|
+
self.history.update({"complete_likelihood": [likelihood.cpu()]})
|
|
379
|
+
self.history.update({"flagged_patients": [flagged_patients]})
|
|
380
|
+
|
|
381
|
+
def _append_history(
|
|
382
|
+
self,
|
|
383
|
+
beta: torch.Tensor,
|
|
384
|
+
omega: torch.Tensor,
|
|
385
|
+
log_mi: torch.Tensor,
|
|
386
|
+
res_var: torch.Tensor,
|
|
387
|
+
complete_likelihood: torch.Tensor,
|
|
388
|
+
flagged_patients: list,
|
|
389
|
+
) -> None:
|
|
390
|
+
# Update the history
|
|
391
|
+
for i, pdu in enumerate(self.model.PDU_names):
|
|
392
|
+
beta_index = self.model.population_betas_names.index(pdu)
|
|
393
|
+
self.history[pdu]["mu"].append(beta[beta_index].cpu())
|
|
394
|
+
self.history[pdu]["sigma_sq"].append(omega[i, i].cpu())
|
|
395
|
+
|
|
396
|
+
for i, cov in enumerate(self.model.covariate_coeffs_names):
|
|
397
|
+
beta_index = self.model.population_betas_names.index(cov)
|
|
398
|
+
self.history[cov].append(beta[beta_index].cpu())
|
|
399
|
+
|
|
400
|
+
self.history["omega"].append(omega.cpu())
|
|
401
|
+
|
|
402
|
+
for i, mi in enumerate(self.model.MI_names):
|
|
403
|
+
self.history[mi].append(log_mi[i].cpu())
|
|
404
|
+
|
|
405
|
+
for i, output in enumerate(self.model.outputs_names):
|
|
406
|
+
self.history[output].append(res_var[i].cpu())
|
|
407
|
+
self.history["complete_likelihood"].append(complete_likelihood.cpu())
|
|
408
|
+
self.history["flagged_patients"].append(flagged_patients)
|
|
409
|
+
|
|
410
|
+
def one_iteration(self, k: int) -> bool:
    """Perform one iteration of SAEM (E-step MCMC sampling, then M-step updates).

    Args:
        k (int): the iteration number. Indices below ``nb_phase1_iterations``
            belong to the exploration phase (learning rate 1, simulated
            annealing); later ones to the smoothing phase.

    Returns:
        bool: True when every estimated parameter's relative change is below
        ``convergence_threshold`` (see ``_check_convergence``).
    """

    if self.verbose:
        print(f"Running iteration {k}")
    # Burn in on the first iteration performed by this optimizer. `run` starts
    # at k=1 (index 0 is the initial state), so testing `k == 0` alone meant
    # `mcmc_first_burn_in` was never applied during a normal run.
    if k == 0 or not getattr(self, "_burn_in_done", False):
        current_iter_burn_in = self.mcmc_first_burn_in
    else:
        current_iter_burn_in = 0

    self.learning_rate_m_step, self.step_size_adaptation = (
        self._compute_learning_rates(k)
    )

    # --- E-step: perform MCMC kernel transitions
    if self.verbose:
        print(" MCMC sampling")
        print(
            f" Current MCMC parameters: step-size={self.step_size:.2f}, adaptation rate={self.step_size_adaptation:.2f}"
        )

    flagged_patients_iter = []
    for _ in range(current_iter_burn_in + self.mcmc_nb_transitions):
        (
            self.current_etas,
            self.current_log_prob,
            self.current_complete_likelihood,
            self.current_pred,
            self.current_thetas,
            self.current_log_pdu,
            self.step_size,
            flagged_patients,
        ) = self.model.mh_step(
            current_etas=self.current_etas,
            current_log_prob=self.current_log_prob,
            current_pred=self.current_pred,
            current_thetas=self.current_thetas,
            current_pdu=self.current_log_pdu,
            step_size=self.step_size,
            learning_rate=self.step_size_adaptation,
            verbose=self.verbose,
        )
        flagged_patients_iter += flagged_patients
    self._burn_in_done = True

    # Update the model's eta and thetas
    self.model.update_eta_samples(self.current_etas)
    self.model.update_map_estimates(self.current_thetas)

    # --- M-Step: Update Population Means, Omega and Residual variance ---

    # 1. Update residual error variances
    sum_sq_res = self.model.sum_sq_residuals(self.current_pred)
    target_res_var: torch.Tensor = (
        sum_sq_res / self.model.n_tot_observations_per_output
    )
    current_res_var: torch.Tensor = self.model.residual_var
    if k < self.nb_phase1_iterations:
        # Exploration: do not let the residual variance shrink too fast
        target_res_var = self._simulated_annealing(current_res_var, target_res_var)

    new_residual_error_var = self._stochastic_approximation(
        current_res_var, target_res_var
    )

    self.model.update_res_var(new_residual_error_var)

    # 2. Update sufficient statistics with stochastic approximation
    (
        self.sufficient_stat_cross_product,
        self.sufficient_stat_outer_product,
        new_beta,
        new_omega,
    ) = self.m_step_update(
        self.current_log_pdu,
        self.sufficient_stat_cross_product,
        self.sufficient_stat_outer_product,
    )

    # Update beta
    self.model.update_betas(new_beta)

    # Update omega
    if k < self.nb_phase1_iterations:
        # Simulated annealing during phase 1 (only the diagonal is kept)
        new_omega_diag = torch.diag(new_omega).to(device)
        current_omega_diag = torch.diag(self.model.omega_pop).to(device)
        annealed_omega_diag = self._simulated_annealing(
            current_omega_diag, new_omega_diag
        )
        new_omega = torch.diag(annealed_omega_diag).to(device)
    self.model.update_omega(new_omega)

    # 3. Update fixed effects MIs
    if self.model.nb_MI > 0:
        # This step is notoriously under-optimized
        self.current_full_res_var_for_MI = self.model.residual_var.index_select(
            0, self.model.full_output_indices
        )
        target_log_MI_np = minimize(
            fun=self.MI_objective_function,
            x0=self.model.log_MI.cpu().squeeze().numpy(),
            method="L-BFGS-B",
            options={"maxfun": self.optim_max_fun},
        ).x
        target_log_MI = torch.from_numpy(target_log_MI_np).to(device)
        new_log_MI = self._stochastic_approximation(
            self.model.log_MI, target_log_MI
        )

        self.model.update_log_mi(new_log_MI)

    if self.verbose:
        print(
            f" Updated MIs: {', '.join([f'{torch.exp(logMI).item():.4f}' for logMI in self.model.log_MI.detach().cpu()])}"
        )
        print(
            f" Updated Betas: {', '.join([f'{beta:.4f}' for beta in self.model.population_betas.detach().cpu().numpy().flatten()])}"
        )
        print(
            f" Updated Omega (diag): {', '.join([f'{val.item():.4f}' for val in torch.diag(self.model.omega_pop.detach().cpu())])}"
        )
        print(
            f" Updated Residual Var: {', '.join([f'{res_var:.4f}' for res_var in self.model.residual_var.detach().cpu().numpy().flatten()])}"
        )

    # Convergence check
    new_params: dict[str, torch.Tensor] = {
        "log_MI": self.model.log_MI,
        "population_betas": self.model.population_betas,
        "population_omega": self.model.omega_pop,
        "residual_error_var": self.model.residual_var,
    }
    is_converged = self._check_convergence(new_params)

    # store history
    self._append_history(
        self.model.population_betas,
        self.model.omega_pop,
        self.model.log_MI,
        self.model.residual_var,
        self.current_complete_likelihood,
        flagged_patients_iter,
    )

    # update prev_params for the next iteration's convergence check
    self.prev_params = new_params

    if self.verbose:
        print("Iter done")
    return is_converged
|
|
564
|
+
|
|
565
|
+
def MI_objective_function(self, log_MI):
    """Negative observation log-likelihood as a function of the log-scaled MIs.

    Used as the scipy.optimize objective during the M-step. The candidate MI
    vector is shared by every patient, so it is repeated once per patient and
    combined with the current individual log-PDUs.

    Args:
        log_MI: candidate log-scaled MI values (array-like, as passed by scipy).

    Returns:
        float: minus the summed observation log-likelihood.
    """
    nb_patients = self.model.nb_patients
    shared_log_mi = (
        torch.as_tensor(log_MI, device=device).unsqueeze(0).repeat((nb_patients, 1))
    )
    # NOTE(review): presence is tested on `patients_pdk` while
    # `patients_pdk_full` is read — confirm the model always defines both together.
    if hasattr(self.model, "patients_pdk"):
        known_params = self.model.patients_pdk_full
    else:
        known_params = torch.empty((nb_patients, 0), device=device)
    log_estimated = torch.cat((self.current_log_pdu, shared_log_mi), dim=1)
    # Patient parameters are ordered PDK, PDU, MI
    thetas = torch.cat((known_params, torch.exp(log_estimated)), dim=1)
    predictions, _ = self.model.predict_outputs_from_theta(thetas)
    log_lik = self.model.log_likelihood_observation(predictions)
    return -log_lik.cpu().sum().item()
|
|
602
|
+
|
|
603
|
+
def run(
    self,
) -> None:
    """
    Run the SAEM estimation: the exploration phase, then the smoothing phase.

    Iteration 0 is the initial state recorded at construction, so the
    exploration phase runs iterations 1 .. nb_phase1_iterations - 1. The
    smoothing phase runs up to nb_phase2_iterations iterations and stops early
    once the convergence criterion has held for ``patience`` consecutive
    iterations.

    Estimates are written to ``self.model`` and their trajectory to
    ``self.history``; nothing is returned. The optimizer state is kept so that
    iterations can resume later with the continue_iterating method.
    """
    if self.verbose:
        print("Starting SAEM Estimation...")
        print(
            f"Initial Population Betas: {', '.join([f'{beta.item():.2f}' for beta in self.model.population_betas.cpu()])}"
        )
        print(
            f"Initial Population MIs: {', '.join([f'{torch.exp(logMI).item():.2f}' for logMI in self.model.log_MI.cpu()])}"
        )
        print(f"Initial Omega:\n{self.model.omega_pop.cpu()}")
        print(f"Initial Residual Variance: {self.model.residual_var.cpu()}")

    print("Phase 1 (exploration):")
    (
        self.convergence_plot_handle,
        self.convergence_plot_fig,
        self.convergence_plot_axes,
    ) = self._build_convergence_plot(
        indiv_figsize=self.plot_indiv_figsize, n_cols=self.plot_columns
    )
    for k in tqdm(range(1, self.nb_phase1_iterations)):
        # Run iteration, do not check for convergence in the exploration phase
        # Iteration 0 is in fact already done (initialization)
        self.one_iteration(k)
        self.current_iteration = k
        # `and` short-circuits: the modulo is skipped when live plotting is off
        if self.live_plot and k % self.plot_frames == 0:
            self._update_convergence_plot()

    if self.nb_phase2_iterations > 0:
        self.current_phase = 2
        print("Phase 2 (smoothing):")
        for k in tqdm(
            range(
                self.nb_phase1_iterations,
                self.nb_phase1_iterations + self.nb_phase2_iterations,
            )
        ):
            # Run iteration
            is_converged = self.one_iteration(k)
            self.current_iteration = k

            if self.live_plot and k % self.plot_frames == 0:
                self._update_convergence_plot()
            # Check for convergence, and stop if criterion matched
            if is_converged:
                self.consecutive_converged_iters += 1
                if self.verbose:
                    print(
                        f"Convergence met. Consecutive iterations: {self.consecutive_converged_iters}/{self.patience}"
                    )
                if self.consecutive_converged_iters >= self.patience:
                    print(
                        f"\nConvergence reached after {k + 1} iterations. Stopping early."
                    )
                    # The final plot refresh below also covers the early stop
                    break
            else:
                self.consecutive_converged_iters = 0
    self._update_convergence_plot()
    plt.close(self.convergence_plot_fig)
    return None
|
|
671
|
+
|
|
672
|
+
def continue_iterating(self, nb_add_iters_ph1=0, nb_add_iters_ph2=0) -> None:
    """
    Resume SAEM iterations after the run method (or a previous call) has finished.

    Iterations resume at ``self.current_iteration + 1``, so the learning-rate
    schedule continues where it stopped instead of restarting from 0.

    Args:
        nb_add_iters_ph1 (int): additional exploration iterations. Only honoured
            while still in phase 1; ``nb_phase1_iterations`` is extended
            accordingly so that these iterations really use the exploration
            settings (learning rate 1, simulated annealing).
        nb_add_iters_ph2 (int): additional smoothing iterations. Early stopping
            applies, with a freshly reset patience counter.
    """
    if self.current_phase == 2:
        if nb_add_iters_ph1 > 0:
            print("Smoothing phase has started, cannot add phase 1 iterations.")
            nb_add_iters_ph1 = 0
    if self.current_phase == 1:
        if nb_add_iters_ph1 > 0:
            print("Continuing phase 1 (exploration):")
            first_k = self.current_iteration + 1
            # Extend phase 1 so that the k-based phase tests in
            # _compute_learning_rates / one_iteration treat these as exploration.
            self.nb_phase1_iterations += nb_add_iters_ph1
            for k in tqdm(range(first_k, self.nb_phase1_iterations)):
                # Run iteration, do not check for convergence in the exploration phase
                self.one_iteration(k)
                self.current_iteration = k

        print("Switching to Phase 2 (smoothing)")
        self.current_phase = 2
        # Phase 2 starts right after the last iteration performed; align the
        # phase boundary used by the k-based tests with that point.
        self.nb_phase1_iterations = min(
            self.nb_phase1_iterations, self.current_iteration + 1
        )

    if nb_add_iters_ph2 > 0:
        first_k = self.current_iteration + 1
        # A stale counter from a previous early stop would end the
        # continuation after a single converged iteration.
        self.consecutive_converged_iters = 0
        for k in tqdm(range(first_k, first_k + nb_add_iters_ph2)):
            # Run iteration
            is_converged = self.one_iteration(k)
            self.current_iteration = k
            # Check for convergence, and stop if criterion matched
            if is_converged:
                self.consecutive_converged_iters += 1
                if self.verbose:
                    print(
                        f"Convergence met. Consecutive iterations: {self.consecutive_converged_iters}/{self.patience}"
                    )
                if self.consecutive_converged_iters >= self.patience:
                    print(
                        f"\nConvergence reached after {k + 1} iterations. Stopping early."
                    )
                    break
            else:
                self.consecutive_converged_iters = 0
    return None
|
|
719
|
+
|
|
720
|
+
def _build_convergence_plot(
    self,
    indiv_figsize: tuple[float, float] = (2.0, 1.2),
    n_cols: int = 3,
):
    """
    Build the grid of convergence plots for the estimated parameters.

    Panels are drawn in this order, one per item:
      * each model intrinsic (MI) parameter;
      * each PDU mean;
      * each PDU variance;
      * each covariate coefficient;
      * each residual error variance;
      * the convergence indicator (``history["complete_likelihood"]``);
      * the number of out-of-bounds patients (``history["flagged_patients"]``).

    When the optional attributes ``true_log_MI``, ``true_log_PDUs``,
    ``true_cov``, ``true_res_var`` or ``training_ranges`` exist and are not
    ``None``, they are drawn as dashed reference lines or shaded bands.

    The line handles are stored in ``self.traces`` so that
    ``_update_convergence_plot`` can refresh them in place. Keys are the
    parameter names; PDU names map to ``{"mu": ..., "sigma_sq": ...}``.

    Args:
        indiv_figsize: (width, height) of a single panel, in inches.
        n_cols: number of panel columns in the grid.

    Returns:
        tuple: ``(handle, fig, axes)``, where ``handle`` is the IPython
        display handle (``None`` when ``smoke_test`` is set), ``fig`` is the
        matplotlib figure and ``axes`` is the 2D array of axes.
    """
    history = self.history
    model = self.model

    # Size the grid from the panels actually drawn below, so the number of
    # axes always matches the number of panels.
    nb_plots = (
        len(model.MI_names)
        + 2 * len(model.PDU_names)  # one panel for mu, one for sigma^2
        + len(model.covariate_coeffs_names)
        + len(model.outputs_names)
        + 2  # convergence indicator + out-of-bounds patients
    )
    nb_cols = n_cols
    nb_rows = int(np.ceil(nb_plots / nb_cols))
    maxiter = self.nb_phase1_iterations + self.nb_phase2_iterations
    fig, axes = plt.subplots(
        nrows=nb_rows,
        ncols=nb_cols,
        figsize=(nb_cols * indiv_figsize[0], nb_rows * indiv_figsize[1]),
        squeeze=False,
        sharex="all",
    )

    # Optional reference values; a missing attribute is treated like None.
    true_log_MI = getattr(self, "true_log_MI", None)
    true_log_PDUs = getattr(self, "true_log_PDUs", None)
    true_cov = getattr(self, "true_cov", None)
    true_res_var = getattr(self, "true_res_var", None)
    training_ranges = getattr(self, "training_ranges", None)

    # Panels are filled in row-major order.
    panel_axes = iter(axes.flat)

    def draw_panel(values, title, ref=None, band=None, ylim=None):
        """Plot one history series on the next free panel and return its line."""
        ax = next(panel_axes)
        ax.set_xlim(0, maxiter)
        if ylim is not None:
            ax.set_ylim(*ylim)
        (line,) = ax.plot(values)
        if ref is not None:
            ax.axhline(y=ref, linestyle="--")
        if band is not None:
            ax.fill_between([0, maxiter], band["low"], band["high"], alpha=0.25)
        ax.set_title(title)
        ax.grid(True)
        return line

    self.traces = {}
    # Model intrinsic parameters
    for mi_name in model.MI_names:
        self.traces[mi_name] = draw_panel(
            [h.item() for h in history[mi_name]],
            f"Model intrinsic {mi_name}",
            ref=None if true_log_MI is None else true_log_MI[mi_name],
            band=None if training_ranges is None else training_ranges[mi_name],
        )
    # PDU means
    for pdu in model.PDU_names:
        self.traces[pdu] = {
            "mu": draw_panel(
                [h.item() for h in history[pdu]["mu"]],
                rf"{pdu}: $\mu$ (log)",
                ref=None if true_log_PDUs is None else true_log_PDUs[pdu]["mean"],
                band=None if training_ranges is None else training_ranges[pdu],
            )
        }
    # PDU variances
    for pdu in model.PDU_names:
        # NOTE(review): the trace is sigma_sq (a variance), but the reference
        # line reads true_log_PDUs[pdu]["sd"]. Confirm whether "sd" already
        # holds a variance or should be squared here.
        self.traces[pdu]["sigma_sq"] = draw_panel(
            [h.item() for h in history[pdu]["sigma_sq"]],
            rf"{pdu}: $\sigma^2$",
            ref=None if true_log_PDUs is None else true_log_PDUs[pdu]["sd"],
        )
    # Coefficients of covariation
    for beta_name in model.covariate_coeffs_names:
        self.traces[beta_name] = draw_panel(
            [h.item() for h in history[beta_name]],
            f"{beta_name}",
            ref=None if true_cov is None else true_cov[beta_name],
        )
    # Residual error variances
    for res_name in model.outputs_names:
        self.traces[res_name] = draw_panel(
            [h.item() for h in history[res_name]],
            rf"{res_name}: $\sigma^2$",
            ref=None if true_res_var is None else true_res_var[res_name],
        )
    # Convergence indicator (complete likelihood)
    self.traces["convergence_ind"] = draw_panel(
        [h.item() for h in history["complete_likelihood"]],
        "Convergence indicator",
    )
    # Number of out-of-bounds patients, on a fixed [0, nb_patients] y-range
    self.traces["oob_patients"] = draw_panel(
        [len(h) for h in history["flagged_patients"]],
        "Out-of-bounds patients",
        ylim=(0, model.nb_patients),
    )

    # Hide the unused trailing panels of the grid.
    for ax in panel_axes:
        ax.set_visible(False)

    if not smoke_test:
        fig.tight_layout()
        handle = display(fig, display_id=True)
    else:
        handle = None
    return (handle, fig, axes)
|
|
902
|
+
|
|
903
|
+
def _update_convergence_plot(self):
    """
    Refresh the convergence plot built by ``_build_convergence_plot``.

    Each line stored in ``self.traces`` receives the full current history of
    its quantity. These quantities are:
      * MI parameters;
      * PDU means and variances;
      * covariate coefficients;
      * residual variances;
      * the complete likelihood;
      * the number of flagged patients.

    Unless ``smoke_test`` is set, the data limits of every panel in
    ``self.convergence_plot_axes`` are then recomputed. Panels whose
    y-autoscaling is enabled are rescaled along y, and the display handle, if
    any, is refreshed with ``self.convergence_plot_fig``.
    """
    history = self.history
    traces = self.traces
    model = self.model

    def set_series(line, values):
        # Derive x from the series itself so that x and y always have the
        # same length, independently of how the iteration counter is kept.
        line.set_data(np.arange(len(values)), values)

    # MI parameters
    for mi_name in model.MI_names:
        set_series(traces[mi_name], [h.item() for h in history[mi_name]])
    # PDU means and variances
    for pdu in model.PDU_names:
        for key in ("mu", "sigma_sq"):
            set_series(traces[pdu][key], [h.item() for h in history[pdu][key]])
    # Coefficients of covariation
    for beta_name in model.covariate_coeffs_names:
        set_series(traces[beta_name], [h.item() for h in history[beta_name]])
    # Residual variances
    for res_name in model.outputs_names:
        set_series(traces[res_name], [h.item() for h in history[res_name]])
    # Convergence indicator and out-of-bounds patient counts
    set_series(
        traces["convergence_ind"],
        [h.item() for h in history["complete_likelihood"]],
    )
    set_series(
        traces["oob_patients"], [len(h) for h in history["flagged_patients"]]
    )

    if not smoke_test:
        for ax in self.convergence_plot_axes.flatten():
            # relim() must run before autoscale_view(). The latter derives the
            # view limits from the data limits that relim() recomputes from
            # the freshly updated lines. The x-axis keeps its fixed range.
            ax.relim()
            ax.autoscale_view(scaley=True, scalex=False)
        if self.convergence_plot_handle is not None:
            self.convergence_plot_handle.update(self.convergence_plot_fig)
|
|
936
|
+
|
|
937
|
+
def plot_convergence_history(
    self,
    indiv_figsize: tuple[float, float] = (2.0, 1.2),
    n_cols: int = 3,
):
    """
    Show a snapshot of the convergence history recorded so far.

    The figure is built (and displayed, unless ``smoke_test`` is set) by
    ``_build_convergence_plot``. It is then closed with ``plt.close`` so
    pyplot no longer keeps a reference to it.

    Args:
        indiv_figsize: (width, height) of a single panel, in inches.
        n_cols: number of panel columns in the grid.
    """
    # NOTE(review): _build_convergence_plot rebinds self.traces to the lines
    # of this new (then closed) figure. Confirm that no live plot still
    # relies on the previous traces.
    _, snapshot_fig, _ = self._build_convergence_plot(
        indiv_figsize=indiv_figsize,
        n_cols=n_cols,
    )
    plt.close(snapshot_fig)
|