vpop-calibration 2.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vpop_calibration/__init__.py +22 -0
- vpop_calibration/data_generation.py +186 -0
- vpop_calibration/diagnostics.py +162 -0
- vpop_calibration/model/__init__.py +3 -0
- vpop_calibration/model/data.py +420 -0
- vpop_calibration/model/gp.py +517 -0
- vpop_calibration/model/plot.py +243 -0
- vpop_calibration/nlme.py +840 -0
- vpop_calibration/ode.py +203 -0
- vpop_calibration/saem.py +945 -0
- vpop_calibration/structural_model.py +200 -0
- vpop_calibration/test/__init__.py +11 -0
- vpop_calibration/test/test_data.py +21 -0
- vpop_calibration/test/test_gp_flavors.py +89 -0
- vpop_calibration/test/test_gp_saem.py +175 -0
- vpop_calibration/test/test_ode_saem.py +121 -0
- vpop_calibration/utils.py +9 -0
- vpop_calibration/vpop.py +50 -0
- vpop_calibration-2.2.8.dist-info/METADATA +78 -0
- vpop_calibration-2.2.8.dist-info/RECORD +22 -0
- vpop_calibration-2.2.8.dist-info/WHEEL +4 -0
- vpop_calibration-2.2.8.dist-info/licenses/LICENSE +21 -0
vpop_calibration/nlme.py
ADDED
|
@@ -0,0 +1,840 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Union, Optional
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from .structural_model import StructuralModel
|
|
7
|
+
from .utils import device
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NlmeModel:
|
|
11
|
+
def __init__(
    self,
    structural_model: StructuralModel,
    patients_df: pd.DataFrame,
    init_log_MI: dict[str, float],
    init_PDU: dict[str, dict[str, float]],
    init_res_var: list[float],
    covariate_map: Optional[dict[str, dict[str, dict[str, str | float]]]] = None,
    error_model_type: str = "additive",
    pred_var_threshold: float = 1e-2,
):
    """Create a non-linear mixed effects model

    Using a structural model (simulation model) and a covariate structure, create a non-linear
    mixed effects model, to be used in PySAEM or another optimizer, or to predict data using a
    covariance structure.

    Args:
        structural_model (StructuralModel): A simulation model defined via the convenience class StructuralModel
        patients_df (DataFrame): the list of patients to be considered, with potential covariate values
            listed, and the name of the protocol arm on which the patient was evaluated (optional - if not
            supplied, `identity` will be used). The `id` column is expected. Any additional column is
            either a covariate (when it appears in `covariate_map`) or a known descriptor (PDK).
        init_log_MI: for each model intrinsic parameter, provide an initial value (log)
        init_PDU: for each patient descriptor unknown parameter, provide an initial `mean` and `sd` of the log
        init_res_var: for each model output, provide an initial residual variance
        covariate_map (optional[dict]): for each PDU, the list of covariates that affect it - each
            associated with a covariation coefficient (to be calibrated)
            Example
            {"pdu_name":
                {"covariate_name":
                    {"coef": "coef_name", "value": initial_value}
                }
            }
        error_model_type (str): either `additive` or `proportional` error model
        pred_var_threshold (float): Threshold of predictive variance that will issue a warning. Default 1e-2.

    Raises:
        ValueError: if a name is both a model intrinsic and a PDU parameter, if a PDU is missing
            from `covariate_map`, if a mapped covariate is not a column of `patients_df`, or if the
            PDK/PDU/MI names do not match the structural model parameter names.
    """
    self.structural_model: StructuralModel = structural_model
    self.pred_var_threshold = pred_var_threshold

    self.MI_names: list[str] = list(init_log_MI.keys())
    self.nb_MI: int = len(self.MI_names)
    self.initial_log_MI = torch.tensor(list(init_log_MI.values())).to(device)
    self.PDU_names: list[str] = list(init_PDU.keys())
    self.nb_PDU: int = len(self.PDU_names)

    overlap = set(self.MI_names) & set(self.PDU_names)
    if overlap:
        raise ValueError(f"Overlapping model intrinsic and PDU descriptors:{overlap}")

    # `.copy()` makes this an independent frame: when duplicates were dropped, pandas would
    # otherwise consider it a slice and warn (SettingWithCopyWarning) on the column added below.
    self.patients_df: pd.DataFrame = patients_df.drop_duplicates().copy()
    self.patients: list[str | int] = self.patients_df["id"].unique().tolist()
    self.nb_patients: int = len(self.patients)
    if "protocol_arm" not in self.patients_df.columns.to_list():
        self.patients_df["protocol_arm"] = "identity"

    additional_columns: list[str] = self.patients_df.drop(
        ["id", "protocol_arm"], axis=1
    ).columns.tolist()

    init_betas_list: list = []
    if covariate_map is None:
        print(
            f"No covariate map provided. All additional columns in `patients_df` will be handled as known descriptors: {additional_columns}"
        )
        self.covariate_map = None
        self.covariate_names = []
        self.covariate_coeffs_names = []
        self.nb_covariates = 0
        # Copy so that population_betas_names and PDU_names never alias the same list
        self.population_betas_names = list(self.PDU_names)
        init_betas_list = [val["mean"] for val in init_PDU.values()]
        self.PDK_names = additional_columns
        self.nb_PDK = len(self.PDK_names)
    else:
        self.covariate_map = covariate_map
        self.population_betas_names: list = []
        # dicts used as insertion-ordered sets: the resulting name lists follow the map order
        # and are identical from one run to the next (plain sets of str are not)
        covariate_set: dict = {}
        covariate_coeffs_set: dict = {}
        for PDU_name in self.PDU_names:
            # Betas layout: for each PDU, its intercept followed by its covariate coefficients.
            # _create_design_matrix relies on this same ordering.
            self.population_betas_names.append(PDU_name)
            init_betas_list.append(init_PDU[PDU_name]["mean"])
            if PDU_name not in covariate_map:
                raise ValueError(
                    f"No covariate map listed for {PDU_name}. Add an empty set if it has no covariate."
                )
            for covariate, coef in covariate_map[PDU_name].items():
                if covariate not in additional_columns:
                    raise ValueError(
                        f"Covariate appears in the map but not in the patient set: {covariate}"
                    )
                covariate_set[covariate] = None
                coef_name = coef["coef"]
                covariate_coeffs_set[coef_name] = None
                self.population_betas_names.append(coef_name)
                init_betas_list.append(coef["value"])
        self.covariate_names = list(covariate_set)
        self.covariate_coeffs_names = list(covariate_coeffs_set)
        self.nb_covariates = len(self.covariate_names)
        # Known descriptors are the remaining extra columns, kept in data frame column order
        self.PDK_names = [col for col in additional_columns if col not in covariate_set]
        self.nb_PDK = len(self.PDK_names)

    print(f"Successfully loaded {self.nb_covariates} covariates:")
    print(self.covariate_names)
    if self.nb_PDK > 0:
        self.patients_pdk = {}
        for patient in self.patients:
            row = self.patients_df.loc[
                self.patients_df["id"] == patient
            ].drop_duplicates()
            self.patients_pdk[patient] = torch.as_tensor(
                row[self.PDK_names].values, device=device
            )
        # Store the full pdk tensor on the device
        self.patients_pdk_full = torch.cat(
            [self.patients_pdk[ind] for ind in self.patients]
        ).to(device)
        print(f"Successfully loaded {self.nb_PDK} known descriptors:")
        print(self.PDK_names)
    else:
        # Create an empty pdk tensor
        self.patients_pdk_full = torch.empty((self.nb_patients, 0), device=device)

    if set(self.PDK_names + self.PDU_names + self.MI_names) != set(
        self.structural_model.parameter_names
    ):
        raise ValueError(
            f"Non-matching descriptor set and structural model parameter set:\n{set(self.PDK_names + self.PDU_names + self.MI_names)}\n{set(self.structural_model.parameter_names)}"
        )

    self.descriptors: list[str] = self.PDK_names + self.PDU_names + self.MI_names
    self.nb_descriptors: int = len(self.descriptors)
    # Descriptors are always provided to the model in the order PDK, PDU, MI; this index
    # reorders them into the structural model's parameter order.
    self.model_input_to_descriptor = torch.as_tensor(
        [
            self.descriptors.index(param)
            for param in self.structural_model.parameter_names
        ],
        device=device,
    ).long()
    self.initial_betas = torch.as_tensor(init_betas_list, device=device)
    self.nb_betas: int = len(self.population_betas_names)
    self.outputs_names: list[str] = self.structural_model.output_names
    self.nb_outputs: int = self.structural_model.nb_outputs
    self.error_model_type: str = error_model_type
    self.init_res_var = torch.as_tensor(init_res_var, device=device)
    # NOTE(review): the "sd" values are placed directly on Omega's diagonal, i.e. they are
    # used as variances of eta (Omega is the covariance of the MultivariateNormal below),
    # although the argument documentation calls them sd - confirm the intended convention.
    self.init_omega = torch.diag(
        torch.as_tensor(
            [float(init_PDU[pdu]["sd"]) for pdu in self.PDU_names], device=device
        )
    )

    # Assemble the list of design matrices from the covariance structure
    self.design_matrices = self._create_all_design_matrices()
    # Store the full design matrix on the device
    self.full_design_matrix = torch.stack(
        [self.design_matrices[p] for p in self.patients]
    ).to(device)

    # Initiate the nlme parameters
    self.log_MI = self.initial_log_MI
    self.population_betas = self.initial_betas
    self.omega_pop = self.init_omega
    self.omega_pop_lower_chol = torch.linalg.cholesky(self.omega_pop).to(device)
    self.residual_var = self.init_res_var
    self.eta_distribution = torch.distributions.MultivariateNormal(
        loc=torch.zeros(self.nb_PDU, device=device),
        covariance_matrix=self.omega_pop,
    )
    self.current_eta_samples = self.sample_individual_etas()
    self.current_map_estimates = self.individual_parameters(
        self.current_eta_samples
    )
|
|
191
|
+
|
|
192
|
+
def _create_design_matrix(self, covariates: dict[str, float]) -> torch.Tensor:
    """Build the design matrix X_i of a single individual.

    X_i has one row per PDU and one column per population beta. In each PDU row, the PDU's own
    intercept column is set to 1. If a covariate map exists, it is followed by one column per
    covariate listed for that PDU, holding this individual's covariate value. The result
    satisfies log(theta_i[PDU]) = X_i @ betas + eta_i.

    Args:
        covariates: the individual's covariate values, keyed by covariate name.

    Returns:
        torch.Tensor of shape (nb_PDU, nb_betas).
    """
    X_i = torch.zeros((self.nb_PDU, self.nb_betas), device=device)
    column = 0
    for row, pdu in enumerate(self.PDU_names):
        # intercept of this PDU
        X_i[row, column] = 1.0
        column += 1
        if self.covariate_map is None:
            continue
        for cov_name in self.covariate_map[pdu]:
            X_i[row, column] = float(covariates[cov_name])
            column += 1
    return X_i
|
|
207
|
+
|
|
208
|
+
def _create_all_design_matrices(self) -> dict[Union[str, int], torch.Tensor]:
|
|
209
|
+
"""Creates a design matrix for each unique individual based on their covariates, given the in the covariates_df."""
|
|
210
|
+
design_matrices = {}
|
|
211
|
+
if self.nb_covariates == 0:
|
|
212
|
+
for ind_id in self.patients:
|
|
213
|
+
design_matrices[ind_id] = self._create_design_matrix({})
|
|
214
|
+
else:
|
|
215
|
+
for ind_id in self.patients:
|
|
216
|
+
individual_covariates = (
|
|
217
|
+
self.patients_df[self.patients_df["id"] == ind_id]
|
|
218
|
+
.iloc[0]
|
|
219
|
+
.drop("id")
|
|
220
|
+
)
|
|
221
|
+
covariates_dict = individual_covariates.to_dict()
|
|
222
|
+
design_matrices[ind_id] = self._create_design_matrix(covariates_dict)
|
|
223
|
+
return design_matrices
|
|
224
|
+
|
|
225
|
+
def add_observations(self, observations_df: pd.DataFrame) -> None:
    """Associate the NLME model with a data frame of observations

    The data frame is validated, then merged with `self.patients_df` to attach each patient's
    protocol arm. It is then flattened into tensors used by the likelihood computations:
    `full_obs_data`, the index tensors `observation_to_patient_index`,
    `observation_to_timestep_index`, `observation_to_task_index` (grouped in
    `prediction_index`), `full_output_indices`, and the per-patient `chunk_sizes`.

    Args:
        observations_df (pd.DataFrame): A data frame of observations, with columns
            - `id`: the patient id. Should be consistent with self.patients_df
            - `time`: the observation time
            - `output_name`: one of the structural model outputs
            - `value`: the observed value

    Raises:
        ValueError: if a required column is missing, if the observed patients differ from
            `self.patients`, if an output is unknown to the model, or if an
            (output, protocol arm) task is not defined by the structural model.
    """
    # Data validation: check every required column before any of them is accessed
    input_columns = observations_df.columns.tolist()
    if "id" not in input_columns:
        raise ValueError(
            "Provided observation data frame should contain `id` column."
        )
    if "time" not in input_columns:
        raise ValueError(
            "Provided observation data frame should contain `time` column."
        )
    if "output_name" not in input_columns:
        raise ValueError(
            "Provided observation data frame should contain `output_name` column."
        )
    if "value" not in input_columns:
        raise ValueError(
            "The provided observations data frame does not contain a `value` column."
        )
    input_patients = set(observations_df["id"].unique())
    expected_patients = set(self.patients)
    if input_patients != expected_patients:
        # Note this check might be unnecessary
        unknown_patients = input_patients - expected_patients
        raise ValueError(
            f"Missing observations for the following patients: {expected_patients - input_patients}"
            + (
                f"; observations for unknown patients: {unknown_patients}"
                if unknown_patients
                else ""
            )
        )
    unique_outputs = set(observations_df["output_name"].unique())
    if not (unique_outputs <= set(self.outputs_names)):
        raise ValueError(
            f"Unknown model output: {unique_outputs - set(self.outputs_names)}"
        )
    if hasattr(self, "observations_tensors"):
        print(
            "Warning: overriding existing observation data frame for the NLME model"
        )
    # Store the raw data frame only once it has passed validation
    self.observations_df = observations_df

    processed_df = observations_df[["id", "output_name", "time", "value"]].merge(
        self.patients_df, how="left", on="id"
    )
    # Each (output, protocol arm) pair is one task of the structural model
    processed_df["task"] = (
        processed_df["output_name"] + "_" + processed_df["protocol_arm"]
    )
    # Lookup tables (first occurrence wins, like list.index) instead of one
    # list.index scan per observation row
    task_to_index: dict = {}
    for idx, task in enumerate(self.structural_model.tasks):
        task_to_index.setdefault(task, idx)
    unknown_tasks = set(processed_df["task"]) - set(task_to_index)
    if unknown_tasks:
        raise ValueError(
            f"Tasks not defined in the structural model: {unknown_tasks}"
        )
    processed_df["task_index"] = processed_df["task"].map(task_to_index)
    output_to_index: dict = {}
    for idx, output in enumerate(self.structural_model.output_names):
        output_to_index.setdefault(output, idx)
    processed_df["output_index"] = processed_df["output_name"].map(output_to_index)
    global_time_steps = (
        processed_df["time"].drop_duplicates().sort_values().to_list()
    )
    time_to_index = {t: i for i, t in enumerate(global_time_steps)}
    processed_df["time_step_index"] = processed_df["time"].map(time_to_index)
    self.global_time_steps = torch.as_tensor(global_time_steps, device=device)
    self.nb_global_time_steps = self.global_time_steps.shape[0]
    self.global_time_steps_expanded = (
        self.global_time_steps.unsqueeze(0)
        .unsqueeze(-1)
        .repeat((self.nb_patients, 1, 1))
    )
    # Browse the observed data set and store relevant elements
    self.observations_tensors: dict = {}
    self.n_tot_observations_per_output = torch.zeros(self.nb_outputs, device=device)
    for ind, patient in enumerate(self.patients):
        this_patient = processed_df.loc[processed_df["id"] == patient]

        tasks_indices = torch.as_tensor(
            this_patient["task_index"].values, device=device
        ).long()
        outputs_indices = torch.as_tensor(
            this_patient["output_index"].values, device=device
        ).long()
        # Add counts of observations to the total per output. The ones share the
        # accumulator's dtype because scatter_add_ requires matching dtypes.
        self.n_tot_observations_per_output.scatter_add_(
            0,
            outputs_indices,
            torch.ones_like(
                outputs_indices, dtype=self.n_tot_observations_per_output.dtype
            ),
        )

        outputs = torch.as_tensor(this_patient["value"].values, device=device)
        time_steps = torch.as_tensor(this_patient["time"].values, device=device)
        time_step_indices = torch.as_tensor(
            this_patient["time_step_index"].values, device=device
        ).long()
        p_index_repeated = torch.full(
            outputs.shape, ind, dtype=torch.int64, device=device
        )

        self.observations_tensors[patient] = {
            "observations": outputs,
            "time_steps": time_steps,
            "time_step_indices": time_step_indices,
            "tasks_indices": tasks_indices,
            "outputs_indices": outputs_indices,
            "p_index_repeated": p_index_repeated,
        }

    # Build the full data set tensors
    self.full_obs_data = torch.cat(
        [self.observations_tensors[p]["observations"] for p in self.patients]
    ).to(device)
    # Construct the indexing tensors
    self.observation_to_patient_index = (
        torch.cat(
            [self.observations_tensors[p]["p_index_repeated"] for p in self.patients]
        )
        .long()
        .to(device)
    )
    self.observation_to_timestep_index = (
        torch.cat(
            [self.observations_tensors[p]["time_step_indices"] for p in self.patients]
        )
        .long()
        .to(device)
    )
    self.observation_to_task_index = (
        torch.cat(
            [self.observations_tensors[p]["tasks_indices"] for p in self.patients]
        )
        .long()
        .to(device)
    )
    # Construct a tuple allowing to index a 3D tensor of outputs into a 1D tensor of outputs
    self.prediction_index = (
        self.observation_to_patient_index,
        self.observation_to_timestep_index,
        self.observation_to_task_index,
    )

    self.full_output_indices = (
        torch.cat(
            [self.observations_tensors[p]["outputs_indices"] for p in self.patients]
        )
        .long()
        .to(device)
    )
    self.chunk_sizes: list[int] = [
        self.observations_tensors[p]["observations"].cpu().shape[0]
        for p in self.patients
    ]
|
|
382
|
+
|
|
383
|
+
def update_omega(self, omega: torch.Tensor) -> None:
    """Update the covariance matrix of the NLME model.

    Also refreshes the lower Cholesky factor and the eta sampling distribution, which both
    derive from Omega.

    Raises:
        AssertionError: if `omega` does not have the current Omega's shape.
        Whatever `torch.linalg.cholesky` / `MultivariateNormal` raise for a non
        positive-definite matrix; in that case the model keeps its previous state.
    """
    assert (
        self.omega_pop.shape == omega.shape
    ), f"Wrong omega shape: {omega.shape}, expected: {self.omega_pop.shape}"
    # Compute every derived quantity before mutating the model, so a failure on a
    # non positive-definite omega cannot leave omega_pop out of sync with its factor.
    lower_chol = torch.linalg.cholesky(omega).to(device)
    eta_distribution = torch.distributions.MultivariateNormal(
        loc=torch.zeros(self.nb_PDU, device=device),
        covariance_matrix=omega,
    )
    self.omega_pop = omega
    self.omega_pop_lower_chol = lower_chol
    self.eta_distribution = eta_distribution
|
|
394
|
+
|
|
395
|
+
def update_res_var(self, residual_var: torch.Tensor) -> None:
    """Store a new residual variance vector, floored at 1e-6 so it stays strictly positive.

    Raises:
        AssertionError: if the new vector's shape differs from the current one.
    """
    current_shape = self.residual_var.shape
    assert current_shape == residual_var.shape, (
        f"Wrong res var shape: {residual_var.shape}, expected: {current_shape}"
    )
    self.residual_var = torch.clamp(residual_var, min=1e-6)
|
|
401
|
+
|
|
402
|
+
def update_betas(self, betas: torch.Tensor) -> None:
    """Replace the population betas; the new tensor must keep the current shape.

    Raises:
        AssertionError: on a shape mismatch.
    """
    expected_shape = self.population_betas.shape
    assert expected_shape == betas.shape, (
        f"Wrong beta shape: {betas.shape}, expected: {expected_shape}"
    )
    self.population_betas = betas
|
|
408
|
+
|
|
409
|
+
def update_log_mi(self, log_MI: torch.Tensor) -> None:
    """Replace the log model-intrinsic parameter values; the shape must not change.

    Raises:
        AssertionError: on a shape mismatch.
    """
    expected_shape = self.log_MI.shape
    assert expected_shape == log_MI.shape, (
        f"Wrong MI shape: {log_MI.shape}, expected: {expected_shape}"
    )
    self.log_MI = log_MI
|
|
415
|
+
|
|
416
|
+
def update_eta_samples(self, eta: torch.Tensor) -> None:
    """Replace the current individual random-effect samples; the shape must not change.

    Raises:
        AssertionError: on a shape mismatch.
    """
    expected_shape = self.current_eta_samples.shape
    assert expected_shape == eta.shape, (
        f"Wrong individual samples shape: {eta.shape}, expected: {expected_shape}"
    )
    self.current_eta_samples = eta
|
|
422
|
+
|
|
423
|
+
def update_map_estimates(self, theta: torch.Tensor) -> None:
    """Replace the current maximum a posteriori estimates; the shape must not change.

    Raises:
        AssertionError: on a shape mismatch.
    """
    expected_shape = self.current_map_estimates.shape
    assert expected_shape == theta.shape, (
        f"Wrong individual parameters shape: {theta.shape}, expected: {expected_shape}"
    )
    self.current_map_estimates = theta
|
|
429
|
+
|
|
430
|
+
def sample_individual_etas(self) -> torch.Tensor:
    """Draw one random-effect vector per patient from the current eta distribution.

    Returns:
        torch.Tensor (size nb_patients x nb_PDUs): one sampled eta vector per row, drawn from
        `self.eta_distribution` expanded to a batch of `nb_patients`.
    """
    return self.eta_distribution.expand([self.nb_patients]).sample()
|
|
439
|
+
|
|
440
|
+
def individual_parameters(
    self,
    individual_etas: torch.Tensor,
) -> torch.Tensor:
    """Compute individual patient parameters

    Combines log(MI) (model intrinsic), the population betas (log(mu)s and covariate
    coefficients) and the individual random effects into per-patient parameters theta_i:

        log(theta_i[PDU]) = X_i @ betas + eta_i    (X_i: the patient's design matrix)
        theta_i[MI] = exp(log_MI), identical for every patient

    Args:
        individual_etas (torch.Tensor): one set of sampled random effects per patient, one
            patient per row, in the order of `self.patients`

    Returns:
        torch.Tensor [nb_patients x nb_parameters]: one row per patient, columns ordered as
        PDK values, then PDU values, then MI values.
    """
    log_pdu = self.full_design_matrix @ self.population_betas + individual_etas
    # MI values are shared: broadcast the single row to every patient
    log_mi = self.log_MI.unsqueeze(0).expand(self.nb_patients, -1)
    # Column order must be PDK, PDU, MI (PDK values are not log-transformed)
    log_params = torch.cat((log_pdu, log_mi), dim=1)
    return torch.cat((self.patients_pdk_full, torch.exp(log_params)), dim=1)
|
|
473
|
+
|
|
474
|
+
def struc_model_inputs_from_theta(self, thetas: torch.Tensor) -> torch.Tensor:
    """Return model inputs for all patients

    Reorders each patient's parameters into the structural model's parameter order. The
    resulting row is then repeated for every global time step, and the time value is
    appended as the last feature.

    Args:
        thetas (torch.Tensor): Parameter values per patient (one by row)

    Returns:
        torch.Tensor of shape (nb_patients, nb_global_time_steps, nb_descriptors + 1).

    Raises:
        ValueError: if no observations were attached via `add_observations`.
    """
    if not hasattr(self, "observations_tensors"):
        raise ValueError(
            "Cannot compute patient predictions without an associated observations data frame."
        )
    ordered_params = thetas[:, self.model_input_to_descriptor]
    per_time_step = ordered_params.unsqueeze(1).expand(
        -1, self.nb_global_time_steps, -1
    )
    full_inputs = torch.cat((per_time_step, self.global_time_steps_expanded), dim=2)
    expected_shape = (
        self.nb_patients,
        self.nb_global_time_steps,
        self.nb_descriptors + 1,
    )
    assert full_inputs.shape == expected_shape
    return full_inputs
|
|
507
|
+
|
|
508
|
+
def predict_outputs_from_theta(
    self, thetas: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return model predictions for all patients

    Args:
        thetas (torch.Tensor): Parameter values per patient (one by row)

    Returns:
        tuple[torch.Tensor, torch.Tensor]: the predicted mean and predicted variance returned by
        `self.structural_model.simulate`, evaluated at the observation indices
        (`self.prediction_index`) with the per-patient `self.chunk_sizes`.
    """
    # inputs shape: [nb_patients, nb_time_steps, nb_params + 1]
    inputs = self.struc_model_inputs_from_theta(thetas)
    simulated = self.structural_model.simulate(
        inputs, self.prediction_index, self.chunk_sizes
    )
    mean, variance = simulated
    return mean, variance
|
|
527
|
+
|
|
528
|
+
def add_residual_error(self, outputs: torch.Tensor) -> torch.Tensor:
    """Add simulated residual noise to flattened model outputs.

    For each observation, eps ~ N(0, residual_var[output]) is drawn. The result is
    - additive:     y = f + eps
    - proportional: y = f * (1 + eps), i.e. variance residual_var * f**2, which matches the
      proportional likelihood used by `log_likelihood_observation`.

    Args:
        outputs (torch.Tensor): predictions aligned with `self.full_output_indices`.

    Returns:
        torch.Tensor: noisy outputs, same shape as `outputs`.

    Raises:
        ValueError: for an unknown `error_model_type`.
    """
    res_var = self.residual_var.index_select(0, self.full_output_indices)
    # torch Normal is parameterized by its standard deviation, residual_var holds variances
    noise = torch.distributions.Normal(
        torch.zeros_like(res_var), torch.sqrt(res_var)
    ).sample()
    if self.error_model_type == "additive":
        return outputs + noise
    if self.error_model_type == "proportional":
        # Multiplying by zero-mean noise alone would center the outputs on 0
        return outputs * (1 + noise)
    raise ValueError(f"Non-implemented error model {self.error_model_type}")
|
|
540
|
+
|
|
541
|
+
def outputs_to_df(self, outputs: torch.Tensor) -> pd.DataFrame:
    """Transform the NLME model outputs to a data frame in order to compare with observed data

    Args:
        outputs (torch.Tensor): Outputs from `self.predict_outputs_from_theta`, one value per
            observation, patients concatenated in the order of `self.patients`. Tensors still
            attached to an autograd graph are accepted.

    Returns:
        pd.DataFrame: A data frame containing the following columns
            - `time`
            - `id`
            - `predicted_value`
            - `output_name`
            - `protocol_arm`
    """
    # detach() so that outputs requiring grad can be converted to numpy
    outputs_list = outputs.detach().cpu().split(self.chunk_sizes)
    df_list = []
    for ind_idx, ind in enumerate(self.patients):
        patient_tensors = self.observations_tensors[ind]
        temp_df = pd.DataFrame(
            {
                "time": patient_tensors["time_steps"].cpu().numpy(),
                "id": ind,
                # numpy (not a torch tensor) so pandas stores plain integers
                "task_index": patient_tensors["tasks_indices"].cpu().numpy(),
                "predicted_value": outputs_list[ind_idx].numpy(),
            }
        )
        temp_df["output_name"] = temp_df["task_index"].apply(
            lambda t: self.outputs_names[
                self.structural_model.task_idx_to_output_idx[t]
            ]
        )
        temp_df["protocol_arm"] = temp_df["task_index"].apply(
            lambda t: self.structural_model.task_idx_to_protocol[t]
        )
        df_list.append(temp_df)
    out_df = pd.concat(df_list)
    return out_df.drop(columns=["task_index"])
|
|
580
|
+
|
|
581
|
+
def _log_prior_etas(self, etas: torch.Tensor) -> torch.Tensor:
    """Evaluate the log-prior of random effect samples under the current eta distribution.

    Args:
        etas (torch.Tensor): Individual samples, assuming eta_i ~ N(0, Omega)

    Returns:
        torch.Tensor: one log-prior value per eta vector, computed by
        `self.eta_distribution.log_prob`, i.e. with k the number of PDUs:

            log P(eta) = -0.5 * (k * log(2pi) + log|Omega| + eta.T * Omega^-1 * eta)
    """
    log_density = self.eta_distribution.log_prob(etas)
    return log_density.to(device)
|
|
597
|
+
|
|
598
|
+
def log_posterior_etas(
    self,
    etas: torch.Tensor,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    list,
]:
    """Compute the log-posterior of a batch of random effects

    Args:
        etas (torch.Tensor): Random effects samples, one row per patient

    Returns:
        tuple of
        - log-posterior of etas per patient (log-likelihood of observations + log-prior)
        - current thetas (individual parameters)
        - log values of current PDU estimation (useful for SAEM)
        - flattened predictions for all observations
        - parameter rows (as lists) of the patients flagged by `variance_level_check`

    Raises:
        ValueError: if no observations were attached via `add_observations`.
    """
    if not hasattr(self, "observations_tensors"):
        raise ValueError(
            "Cannot compute log-posterior without an associated observations data frame."
        )
    thetas: torch.Tensor = self.individual_parameters(individual_etas=etas)
    # Run the surrogate model
    predictions, variances = self.predict_outputs_from_theta(thetas)
    # Validate the predictive variance magnitude per patient
    flagged_indices = self.variance_level_check(
        variances.split(self.chunk_sizes), self.pred_var_threshold
    )
    flagged = [thetas[idx, :].tolist() for idx in flagged_indices]

    log_prior: torch.Tensor = self._log_prior_etas(etas)
    log_lik = self.log_likelihood_observation(predictions)

    pdu_columns = thetas[:, self.nb_PDK : self.nb_PDK + self.nb_PDU]
    return (
        log_lik + log_prior,
        thetas,
        torch.log(pdu_columns),
        predictions,
        flagged,
    )
|
|
653
|
+
|
|
654
|
+
def calculate_residuals(
|
|
655
|
+
self, observed_data: torch.Tensor, predictions: torch.Tensor
|
|
656
|
+
) -> torch.Tensor:
|
|
657
|
+
"""Calculates residuals based on the error model for a single patient
|
|
658
|
+
|
|
659
|
+
Args:
|
|
660
|
+
observed_data: Tensor of observations
|
|
661
|
+
predictions: Tensor of predictions
|
|
662
|
+
|
|
663
|
+
Returns:
|
|
664
|
+
torch.Tensor: a tensor of residual values
|
|
665
|
+
"""
|
|
666
|
+
if self.error_model_type == "additive":
|
|
667
|
+
return observed_data - predictions
|
|
668
|
+
elif self.error_model_type == "proportional":
|
|
669
|
+
return (observed_data - predictions) / predictions
|
|
670
|
+
else:
|
|
671
|
+
raise ValueError("Unsupported error model type.")
|
|
672
|
+
|
|
673
|
+
def sum_sq_residuals(self, prediction: torch.Tensor) -> torch.Tensor:
    """Sum the squared residuals of all stored observations, separately for each model output.

    Residuals come from `self.calculate_residuals`, so they follow the configured error model.

    Args:
        prediction (torch.Tensor): predictions aligned with `self.full_obs_data`.

    Returns:
        torch.Tensor of shape (nb_outputs,): sum of squared residuals per output.
    """
    sq_residuals = torch.square(
        self.calculate_residuals(self.full_obs_data, prediction)
    )
    # The accumulator shares the residuals' device and dtype: scatter_add_ requires
    # matching dtypes, which a default-dtype zeros tensor does not guarantee.
    sum_residuals = sq_residuals.new_zeros(self.nb_outputs)
    sum_residuals.scatter_add_(0, self.full_output_indices, sq_residuals)
    return sum_residuals
|
|
680
|
+
|
|
681
|
+
def log_likelihood_observation(
    self,
    predictions: torch.Tensor,
) -> torch.Tensor:
    """Compute the Gaussian log-likelihood of all observations, summed per patient.

    With ``sigma^2`` the residual variance of an observation's output
    (``self.residual_var`` indexed by ``self.full_output_indices``), each
    observation ``y`` with prediction ``f`` is modelled as
    ``y ~ N(f, sigma^2)`` (additive error model) or
    ``y ~ N(f, sigma^2 * f^2)`` (proportional error model).

    Args:
        predictions: model predictions aligned element-wise with
            ``self.full_obs_data``.

    Returns:
        torch.Tensor: log-likelihood per patient, of length
        ``self.nb_patients``; observations are assigned to patients through
        ``self.observation_to_patient_index``. Zero predictions under the
        proportional model yield non-finite values.

    Raises:
        ValueError: if the error model type is not supported.
    """
    residuals: torch.Tensor = self.calculate_residuals(
        self.full_obs_data, predictions
    )
    res_error_var = self.residual_var.index_select(0, self.full_output_indices)
    if self.error_model_type == "additive":
        # residuals = y - f and Var(y) = sigma^2
        variance = res_error_var
    elif self.error_model_type == "proportional":
        # residuals = (y - f) / f and Var(y) = sigma^2 * f^2
        variance = res_error_var * torch.square(predictions)
    else:
        raise ValueError("Non supported error type.")
    # In both models residuals**2 / sigma^2 == (y - f)**2 / Var(y). The
    # proportional residuals are already scaled by f, so dividing them by
    # sigma^2 * f^2 would apply the scaling twice.
    log_lik_full = -0.5 * (
        torch.log(2 * torch.pi * variance) + torch.square(residuals) / res_error_var
    )
    # Match dtype/device of the summands, as required by ``scatter_add_``.
    log_lik_per_patient = log_lik_full.new_zeros(self.nb_patients)
    log_lik_per_patient.scatter_add_(
        0, self.observation_to_patient_index, log_lik_full
    )
    return log_lik_per_patient
|
|
712
|
+
|
|
713
|
+
def mh_step(
    self,
    current_etas: torch.Tensor,
    current_log_prob: torch.Tensor,
    current_pred: torch.Tensor,
    current_pdu: torch.Tensor,
    current_thetas: torch.Tensor,
    step_size: float,
    learning_rate: float,
    target_acceptance_rate: float = 0.234,
    verbose: bool = False,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    float,
    list,
]:
    """Perform one step of a random-walk Metropolis-Hastings transition kernel.

    One Gaussian proposal is drawn for every patient at once. Each patient then
    accepts or rejects its own proposal independently. Finally, the step size
    is adapted multiplicatively from the observed acceptance rate.

    Args:
        current_etas (torch.Tensor): values of the individual random effects for all
            patients, one row per patient
        current_log_prob (torch.Tensor): log posterior of the current random effects,
            one entry per patient
        current_pred (torch.Tensor): model predictions for the current random effects,
            one entry per observation
        current_pdu (torch.Tensor): current values of the log PDUs, one row per patient
        current_thetas (torch.Tensor): current individual parameters, one row per patient
        step_size (float): current value of the MH step size
        learning_rate (float): current learning rate (defined by the optimization algorithm)
        target_acceptance_rate (float, optional): Target for the MCMC acceptance rate. Defaults to 0.234 [1].
        verbose (bool, optional): print the acceptance rate when True. Defaults to False.

    [1] Sherlock C. Optimal Scaling of the Random Walk Metropolis: General Criteria for the 0.234 Acceptance Rule. Journal of Applied Probability. 2013;50(1):1-15. doi:10.1239/jap/1363784420

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, list]:
        - updated individual random effects
        - updated log posterior, one entry per patient
        - -2 times the sum over patients of the updated log posterior
        - updated predictions, one entry per observation
        - updated thetas (individual parameters)
        - updated values of log PDUs
        - updated step size
        - the flagged-patient list returned by ``log_posterior_etas`` for the
          *proposal*: individual-parameter rows of patients whose predictive
          variance exceeds the threshold, computed before accept/reject
    """

    # NOTE(review): with row-vector noise z, ``z @ L`` has covariance L^T L.
    # If ``omega_pop_lower_chol`` is the lower Cholesky factor L of Omega
    # (Omega = L L^T), the proposal covariance equals Omega only when Omega is
    # diagonal; confirm whether ``z @ L.T`` was intended. The proposal stays a
    # symmetric random walk either way, so the kernel remains valid.
    proposal_noise = (
        torch.randn_like(current_etas, device=device) @ self.omega_pop_lower_chol
    )
    proposal_etas = current_etas + step_size * proposal_noise
    (
        proposal_log_prob,
        proposal_theta,
        proposal_log_pdus,
        proposal_pred,
        warnings,
    ) = self.log_posterior_etas(proposal_etas)
    # Symmetric proposal: the MH ratio reduces to the posterior ratio.
    # A NaN delta compares False below, so such proposals are rejected.
    deltas: torch.Tensor = proposal_log_prob - current_log_prob
    log_u: torch.Tensor = torch.log(torch.rand_like(deltas, device=device))
    accept_mask: torch.Tensor = log_u < deltas
    # Broadcast the per-patient decision to per-parameter columns ...
    accept_mask_parameters = accept_mask.unsqueeze(-1).expand(
        -1, current_etas.shape[1]
    )
    # ... and to per-observation entries of the flat prediction tensor.
    accept_mask_predictions = accept_mask.index_select(
        0, self.observation_to_patient_index
    )
    new_etas = torch.where(accept_mask_parameters, proposal_etas, current_etas).to(
        device
    )
    new_log_pdus = torch.where(
        accept_mask_parameters, proposal_log_pdus, current_pdu
    ).to(device)
    new_log_prob = torch.where(accept_mask, proposal_log_prob, current_log_prob).to(
        device
    )
    new_complete_likelihood = -2 * new_log_prob.sum(dim=0)
    new_pred = torch.where(accept_mask_predictions, proposal_pred, current_pred).to(
        device
    )
    new_acceptance_rate: float = accept_mask.cpu().float().mean().item()
    if verbose:
        print(f" Acceptance rate: {new_acceptance_rate:.2f}")
    # Grow the step when accepting more often than targeted, shrink otherwise.
    new_step_size: float = step_size * np.exp(
        learning_rate * (new_acceptance_rate - target_acceptance_rate)
    )
    new_theta = torch.where(
        accept_mask.unsqueeze(1).expand(-1, current_thetas.shape[1]),
        proposal_theta,
        current_thetas,
    ).to(device)
    return (
        new_etas,
        new_log_prob,
        new_complete_likelihood,
        new_pred,
        new_theta,
        new_log_pdus,
        new_step_size,
        warnings,
    )
|
|
811
|
+
|
|
812
|
+
def map_estimates_descriptors(self) -> pd.DataFrame:
    """Tabulate the current MAP estimates of the individual parameters.

    Returns:
        pd.DataFrame: the MAP estimate tensor copied to CPU, with one column
        per entry of ``self.descriptors`` (presumably one row per patient).

    Raises:
        ValueError: if no MAP estimate has been computed yet.
    """
    estimates = self.current_map_estimates
    if estimates is not None:
        return pd.DataFrame(data=estimates.cpu().numpy(), columns=self.descriptors)
    raise ValueError("No estimation available yet. Run the algorithm first.")
|
|
821
|
+
|
|
822
|
+
def map_estimates_predictions(self) -> pd.DataFrame:
    """Simulate the model outputs at the current MAP estimates.

    Returns:
        pd.DataFrame: the first element returned by
        ``self.predict_outputs_from_theta`` for the MAP estimates, formatted
        by ``self.outputs_to_df``.

    Raises:
        ValueError: if no MAP estimate has been computed yet.
    """
    map_theta = self.current_map_estimates
    if map_theta is not None:
        predicted, _unused = self.predict_outputs_from_theta(map_theta)
        return self.outputs_to_df(predicted)
    raise ValueError(
        "No estimation available yet. Run the optimization algorithm first."
    )
|
|
831
|
+
|
|
832
|
+
def variance_level_check(
    self, var_list: tuple[torch.Tensor, ...], threshold: float
) -> list:
    """List the patients whose predictive variance exceeds a threshold.

    Args:
        var_list: one variance tensor per patient, indexed by patient position;
            only the first ``self.nb_patients`` entries are inspected
            (presumably 1-D per-output variances — confirm with callers).
        threshold: variance level above which a patient is flagged.

    Returns:
        list: ascending patient indices ``i`` whose maximum of ``var_list[i]``
        along dim 0 is strictly greater than ``threshold``.
    """
    return [
        patient
        for patient in range(self.nb_patients)
        if var_list[patient].cpu().max(0).values > threshold
    ]
|