vpop-calibration 2.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vpop_calibration/__init__.py +22 -0
- vpop_calibration/data_generation.py +186 -0
- vpop_calibration/diagnostics.py +162 -0
- vpop_calibration/model/__init__.py +3 -0
- vpop_calibration/model/data.py +420 -0
- vpop_calibration/model/gp.py +517 -0
- vpop_calibration/model/plot.py +243 -0
- vpop_calibration/nlme.py +840 -0
- vpop_calibration/ode.py +203 -0
- vpop_calibration/saem.py +945 -0
- vpop_calibration/structural_model.py +200 -0
- vpop_calibration/test/__init__.py +11 -0
- vpop_calibration/test/test_data.py +21 -0
- vpop_calibration/test/test_gp_flavors.py +89 -0
- vpop_calibration/test/test_gp_saem.py +175 -0
- vpop_calibration/test/test_ode_saem.py +121 -0
- vpop_calibration/utils.py +9 -0
- vpop_calibration/vpop.py +50 -0
- vpop_calibration-2.2.8.dist-info/METADATA +78 -0
- vpop_calibration-2.2.8.dist-info/RECORD +22 -0
- vpop_calibration-2.2.8.dist-info/WHEEL +4 -0
- vpop_calibration-2.2.8.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .nlme import NlmeModel
|
|
2
|
+
from .saem import PySaem
|
|
3
|
+
from .structural_model import StructuralGp, StructuralOdeModel
|
|
4
|
+
from .model import *
|
|
5
|
+
from .ode import OdeModel
|
|
6
|
+
from .vpop import generate_vpop_from_ranges
|
|
7
|
+
from .data_generation import simulate_dataset_from_omega, simulate_dataset_from_ranges
|
|
8
|
+
from .diagnostics import check_surrogate_validity_gp, plot_map_estimates
|
|
9
|
+
|
|
10
|
+
# Public API of the package: the names exported by
# `from vpop_calibration import *`.
# NOTE(review): "GP" is not imported explicitly in this module; it presumably
# comes from `from .model import *` — confirm it is defined/exported by
# vpop_calibration/model/__init__.py.
__all__ = [
    "GP",
    "OdeModel",
    "StructuralGp",
    "StructuralOdeModel",
    "NlmeModel",
    "PySaem",
    "simulate_dataset_from_omega",
    "simulate_dataset_from_ranges",
    "generate_vpop_from_ranges",
    "check_surrogate_validity_gp",
    "plot_map_estimates",
]
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from .ode import OdeModel
|
|
6
|
+
from .vpop import generate_vpop_from_ranges
|
|
7
|
+
from .structural_model import StructuralOdeModel
|
|
8
|
+
from .nlme import NlmeModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def simulate_dataset_from_ranges(
    ode_model: OdeModel,
    log_nb_individuals: int,
    param_ranges: dict[str, dict[str, float | bool]],
    initial_conditions: np.ndarray,
    protocol_design: Optional[pd.DataFrame],
    residual_error_variance: Optional[np.ndarray],
    error_model: Optional[str],  # "additive" or "proportional"
    time_steps: np.ndarray,
) -> pd.DataFrame:
    """Generate a simulated data set with an ODE model.

    Simulates a dataset, e.g. for training a surrogate model. A virtual
    population is sampled from ``param_ranges`` (the parameter space is
    explored with Sobol sequences). Every patient is crossed with every
    protocol arm and every model output, and the ODE model is run at
    ``time_steps``.

    Args:
        ode_model (OdeModel): The ODE model to simulate. Its ``param_names``
            must be exactly covered by ``param_ranges`` plus the override
            columns of ``protocol_design``.
        log_nb_individuals (int): The number of simulated patients will be
            2^this parameter. Each patient is then simulated in every
            protocol arm.
        param_ranges (dict[str, dict]): For each explored parameter, a dict
            describing the search space: 'low': low bound, 'high': high bound,
            and 'log': True if the search space is log-scaled.
        initial_conditions (array): set of initial conditions, one for each
            variable.
        protocol_design (optional): a DataFrame with a `protocol_arm` column,
            and one column per parameter override. If None, a single
            "identity" arm without overrides is used.
        residual_error_variance (np.array, optional): A 1D array of residual
            error variances, one per output. Required when ``error_model`` is
            not None.
        error_model (str, optional): the type of error model:
            - "additive": value + eps
            - "proportional": value * (1 + eps)
            where eps ~ N(0, residual_error_variance). None adds no noise.
        time_steps (np.array): an array with the time points.

    Returns:
        pd.DataFrame: A long DataFrame with columns 'id', 'protocol_arm',
        'time', the parameter names defined by ``param_ranges``,
        'output_name' and 'value'. Protocol override columns are dropped;
        they are described by the 'protocol_arm' column.

    Raises:
        ValueError: if ``error_model`` is set but ``residual_error_variance``
            is None, if ``error_model`` is not a supported value, or if the
            parameters defined by the ranges and the protocol do not match
            ``ode_model.param_names``.

    Notes:
        If a parameter appears both in the ranges and in the protocol design,
        the ranges take precedence.
    """

    # Validate the noise configuration before running any (costly) simulation.
    if error_model is not None:
        if residual_error_variance is None:
            raise ValueError("Undefined residual error variance.")
        if error_model not in ("additive", "proportional"):
            raise ValueError(f"Incorrect error_model choice: {error_model}")

    # Validate input data
    params_to_explore = list(param_ranges.keys())

    if protocol_design is None:
        print("No protocol")
        params_in_protocol: list[str] = []
        protocol_design_filt = pd.DataFrame({"protocol_arm": ["identity"]})
    else:
        protocol_columns = protocol_design.drop(
            "protocol_arm", axis=1
        ).columns.tolist()
        # Find the parameters that appear both in the ranges and the protocol
        overlap = set(params_to_explore) & set(protocol_columns)
        if overlap:
            protocol_design_filt = protocol_design.drop(list(overlap), axis=1)
            print(
                f"Warning: ignoring entries {overlap} from the protocol design (already defined in the ranges)."
            )
        else:
            protocol_design_filt = protocol_design
        # Only the overrides that survive the filtering are described by the
        # protocol arm. Overlapping parameters are sampled per patient from the
        # ranges, so their columns must be kept in the output.
        params_in_protocol = [p for p in protocol_columns if p not in overlap]

    params = params_to_explore + params_in_protocol
    expected_params = set(ode_model.param_names)
    provided_params = set(params)
    if provided_params != expected_params:
        missing = expected_params - provided_params
        unknown = provided_params - expected_params
        details = []
        if missing:
            details.append(f"Under-defined system: missing {missing}")
        if unknown:
            details.append(
                f"Over-defined system: unknown parameters {unknown} (not in the ODE model)"
            )
        raise ValueError("; ".join(details))

    # Generate the vpop using sobol sequences
    patients_df = generate_vpop_from_ranges(log_nb_individuals, param_ranges)

    # Add a choice of protocol arm for each patient
    protocol_arms = pd.DataFrame(protocol_design_filt["protocol_arm"].drop_duplicates())
    patients_df = patients_df.merge(protocol_arms, how="cross")
    # Add the outputs for each patient
    outputs = pd.DataFrame({"output_name": ode_model.variable_names})
    patients_df = patients_df.merge(outputs, how="cross")
    # Simulate the ODE model
    output_df = ode_model.run_trial(
        patients_df, initial_conditions, protocol_design_filt, time_steps
    )
    # Pivot to wide to add noise per model output
    wide_output = output_df.pivot_table(
        index=["id", *ode_model.param_names, "time", "protocol_arm"],
        columns="output_name",
        values="predicted_value",
    ).reset_index()

    if error_model is not None:
        # Add noise to the data; one column of draws per output, with that
        # output's residual variance (broadcast over rows).
        noise = np.random.normal(
            np.zeros_like(residual_error_variance),
            np.sqrt(residual_error_variance),
            (wide_output.shape[0], ode_model.nb_outputs),
        )
        if error_model == "additive":
            wide_output[ode_model.variable_names] += noise
        else:  # "proportional" (validated at the top of the function)
            wide_output[ode_model.variable_names] += (
                noise * wide_output[ode_model.variable_names]
            )

    # Pivot back to long format
    long_output = wide_output.melt(
        id_vars=[
            "id",
            "protocol_arm",
            "time",
            *ode_model.param_names,
        ],
        value_vars=ode_model.variable_names,
        var_name="output_name",
        value_name="value",
    )
    # Remove the protocol arm overrides from the data set; they are described
    # by the protocol_arm column now.
    long_output = long_output.drop(params_in_protocol, axis=1)
    return long_output
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def simulate_dataset_from_omega(
    ode_model: OdeModel,
    protocol_design: pd.DataFrame,
    time_steps: np.ndarray,
    init_conditions: np.ndarray,
    log_mi: dict[str, float],
    log_pdu: dict[str, dict[str, float]],
    error_model: str,
    res_var: list[float],
    covariate_map: dict[str, dict[str, dict[str, str | float]]],
    patient_covariates: pd.DataFrame,
) -> pd.DataFrame:
    """Generate a synthetic data set using an ODE model and population distributions of parameters.

    An NLME model is built around ``ode_model``. Individual random effects
    are sampled and converted into individual parameters. Every output of
    every patient is then simulated at every time step, and residual error
    is added to the predictions.

    Args:
        ode_model (OdeModel): The equations to be simulated.
        protocol_design (pd.DataFrame): Protocol arms, forwarded to
            ``StructuralOdeModel``.
        time_steps (np.ndarray): Time points; every output of every patient
            is simulated at each of them.
        init_conditions (np.ndarray): Initial conditions, forwarded to
            ``StructuralOdeModel``.
        log_mi (dict[str, float]): Forwarded to ``NlmeModel``. Presumably the
            log-scale values of the model-intrinsic parameters — confirm
            against ``NlmeModel``.
        log_pdu (dict[str, dict[str, float]]): Forwarded to ``NlmeModel``.
            Presumably the log-scale population distribution of each patient
            descriptor — confirm the expected inner keys in ``NlmeModel``.
        error_model (str): Residual error model name, forwarded to
            ``NlmeModel``.
        res_var (list[float]): Residual error variances (presumably one per
            output), forwarded to ``NlmeModel``.
        covariate_map (dict[str, dict[str, dict[str, str | float]]]): Covariate
            mapping, forwarded to ``NlmeModel``.
        patient_covariates (pd.DataFrame): Patient table; must contain at
            least the ``id`` and ``protocol_arm`` columns.

    Returns:
        pd.DataFrame: The simulated (noisy) observations as returned by
        ``NlmeModel.outputs_to_df``, with the ``predicted_value`` column
        renamed to ``value``.
    """

    structural_model = StructuralOdeModel(ode_model, protocol_design, init_conditions)
    nlme_model = NlmeModel(
        structural_model,
        patient_covariates,
        log_mi,
        log_pdu,
        res_var,
        covariate_map,
        error_model,
    )
    # Sample random effects and turn them into individual parameter sets
    etas = nlme_model.sample_individual_etas()
    theta = nlme_model.individual_parameters(etas)
    vpop = pd.DataFrame(data=theta.cpu().numpy(), columns=nlme_model.descriptors)
    vpop["id"] = nlme_model.patients
    # Exact duplicate (id, protocol_arm) rows in the covariate table would
    # otherwise multiply each patient's rows in the merge below.
    protocol_arms = patient_covariates[["id", "protocol_arm"]].drop_duplicates()
    vpop = vpop.merge(protocol_arms, on=["id"], how="left")
    vpop = vpop.merge(
        pd.DataFrame(data=nlme_model.outputs_names, columns=["output_name"]),
        how="cross",
    )
    time_df = pd.DataFrame(data=time_steps, columns=["time"])
    vpop = vpop.merge(time_df, how="cross")
    # add a dummy observation value: only the (id, output, time) grid matters
    vpop["value"] = 1.0
    nlme_model.add_observations(vpop)

    out_tensor, _ = nlme_model.predict_outputs_from_theta(theta)
    out_with_noise = nlme_model.add_residual_error(out_tensor)
    out_df = nlme_model.outputs_to_df(out_with_noise)
    out_df = out_df.rename(columns={"predicted_value": "value"})

    return out_df
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
|
|
5
|
+
from .nlme import NlmeModel
|
|
6
|
+
from .saem import PySaem
|
|
7
|
+
from .model.gp import GP
|
|
8
|
+
from .structural_model import StructuralGp
|
|
9
|
+
from .utils import smoke_test
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def check_surrogate_validity_gp(nlme_model: NlmeModel) -> tuple[dict, dict]:
    """Compare the NLME individual estimates with the GP surrogate's training domain.

    For each patient descriptor, the log-transformed values returned by
    ``nlme_model.map_estimates_descriptors()`` are compared with the range of
    the log-transformed GP training inputs. Two figures are created; this
    function does not call ``plt.show()``:

    - one histogram per descriptor (training samples vs estimates), with the
      training bounds drawn as dashed lines;
    - a pairwise scatter matrix of all descriptors (rows are the y axis,
      columns the x axis).

    Args:
        nlme_model (NlmeModel): A model whose structural model is a
            ``StructuralGp``.

    Returns:
        tuple[dict, dict]:
            - diagnostics: ``{descriptor: {"above": ids or None, "below": ids or None}}``,
              the patients whose estimate lies above / below the training range.
            - recommended_ranges: ``{descriptor: {"low": str, "high": str, "log": True}}``,
              the training range widened to include every estimate. ``low``
              and ``high`` are natural-log values formatted as strings with
              two decimals.

    Raises:
        AssertionError: if the structural model is not a ``StructuralGp``
            (only while assertions are enabled).
    """
    pdus = nlme_model.descriptors
    gp_model_struct = nlme_model.structural_model
    assert isinstance(
        gp_model_struct, StructuralGp
    ), "Posterior surrogate validity check only implemented for GP structural model."

    gp_model: GP = gp_model_struct.gp_model
    train_data = gp_model.data.full_df_raw[pdus].drop_duplicates()

    map_data = nlme_model.map_estimates_descriptors()
    patients = nlme_model.patients

    # Work in log space; compute each transform once and reuse it in both figures.
    log_train = {param: np.log(train_data[param]) for param in pdus}
    log_map = {param: np.log(map_data[param]) for param in pdus}

    n_plots = len(pdus)
    n_cols = 3
    n_rows = int(np.ceil(n_plots / n_cols))

    scaling_indiv_plots = 3
    _, axes1 = plt.subplots(
        n_rows,
        n_cols,
        squeeze=False,
        figsize=[scaling_indiv_plots * n_cols, scaling_indiv_plots * n_rows],
    )
    diagnostics = {}
    recommended_ranges = {}
    for k, param in enumerate(pdus):
        i, j = divmod(k, n_cols)
        train_samples = log_train[param]
        train_min, train_max = train_samples.min(axis=0), train_samples.max(axis=0)

        map_samples = log_map[param]
        # NOTE(review): flags are positional; this assumes map_data rows are
        # ordered like nlme_model.patients — confirm.
        flag_high = np.where(map_samples > train_max)[0]
        flag_low = np.where(map_samples < train_min)[0]
        recommend_low, recommend_high = train_min, train_max
        param_diagnostic = {}
        if flag_high.shape[0] > 0:
            param_diagnostic["above"] = [patients[p] for p in flag_high]
            recommend_high = map_samples.max()
        else:
            param_diagnostic["above"] = None
        if flag_low.shape[0] > 0:
            param_diagnostic["below"] = [patients[p] for p in flag_low]
            recommend_low = map_samples.min()
        else:
            param_diagnostic["below"] = None
        diagnostics[param] = param_diagnostic
        recommended_ranges[param] = {
            "low": f"{recommend_low:.2f}",
            "high": f"{recommend_high:.2f}",
            "log": True,
        }

        ax = axes1[i, j]
        ax.hist([train_samples, map_samples], density=True)
        ax.axvline(train_min, linestyle="dashed", color="black")
        ax.axvline(train_max, linestyle="dashed", color="black")
        ax.set_title(f"{param}")

    # Hide the unused cells at the end of the last histogram row.
    for k in range(n_plots, n_rows * n_cols):
        axes1[k // n_cols, k % n_cols].set_visible(False)

    scaling_2by2_plots = 2
    _, axes2 = plt.subplots(
        n_plots,
        n_plots,
        squeeze=False,
        figsize=[scaling_2by2_plots * n_plots, scaling_2by2_plots * n_plots],
        sharex="col",
        sharey="row",
    )
    for k1, param1 in enumerate(pdus):
        for k2, param2 in enumerate(pdus):
            ax = axes2[k1, k2]
            if k1 != k2:
                # param 1 is the row -> y axis
                # param 2 is the column -> x axis
                ax.scatter(log_train[param2], log_train[param1], alpha=0.5, s=1.0)
                ax.scatter(log_map[param2], log_map[param1], s=5)
            # Labels are set on the outer cells even when they are on the
            # (empty) diagonal. Otherwise the first row would never get its
            # y label and the last column would never get its x label.
            if k2 == 0:
                ax.set_ylabel(param1)
            if k1 == n_plots - 1:
                ax.set_xlabel(param2)
    return diagnostics, recommended_ranges
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def plot_map_estimates(nlme_model: NlmeModel) -> None:
    """Plot observations against the model predictions, per output and protocol arm.

    Predictions come from ``nlme_model.map_estimates_predictions()``. The
    figure has one subplot per (protocol arm, output) pair: rows are protocol
    arms, columns are outputs. Each patient gets its own colour; observations
    are drawn as ``+`` markers and predictions as a line. The figure is laid
    out and shown unless ``smoke_test`` is set.

    Args:
        nlme_model (NlmeModel): The fitted model providing observations and
            predictions.
    """
    observed = nlme_model.observations_df
    simulated_df = nlme_model.map_estimates_predictions()

    n_cols = nlme_model.nb_outputs
    n_rows = nlme_model.structural_model.nb_protocols
    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
    )

    cmap = plt.get_cmap("Spectral")
    colors = cmap(np.linspace(0, 1, nlme_model.nb_patients))
    # Position of each patient id in nlme_model.patients (first occurrence, as
    # list.index gives), computed once instead of an O(n) scan per curve.
    patient_position: dict = {}
    for position, patient_id in enumerate(nlme_model.patients):
        patient_position.setdefault(patient_id, position)

    for output_index, output_name in enumerate(nlme_model.outputs_names):
        for protocol_index, protocol_arm in enumerate(
            nlme_model.structural_model.protocols
        ):
            obs_loop = observed.loc[
                (observed["output_name"] == output_name)
                & (observed["protocol_arm"] == protocol_arm)
            ]
            pred_loop = simulated_df.loc[
                (simulated_df["output_name"] == output_name)
                & (simulated_df["protocol_arm"] == protocol_arm)
            ]
            ax = axes[protocol_index, output_index]
            ax.set_xlabel("Time")
            patients_protocol = obs_loop["id"].drop_duplicates().to_list()
            for patient_ind in patients_protocol:
                color = colors[patient_position[patient_ind]]
                patient_obs = obs_loop.loc[obs_loop["id"] == patient_ind]
                patient_pred = pred_loop.loc[pred_loop["id"] == patient_ind]
                time_vec = patient_obs["time"].values
                sorted_indices = np.argsort(time_vec)
                sorted_times = time_vec[sorted_indices]
                obs_vec = patient_obs["value"].values[sorted_indices]
                ax.plot(
                    sorted_times,
                    obs_vec,
                    "+",
                    color=color,
                    linewidth=2,
                    alpha=0.6,
                )
                if patient_pred.shape[0] > 0:
                    if "time" in patient_pred.columns:
                        # Sort predictions by their own time column: their row
                        # order and count need not match the observations, so
                        # reusing the observation sort order could misplace
                        # points or index out of bounds.
                        patient_pred = patient_pred.sort_values("time", kind="stable")
                        pred_times = patient_pred["time"].values
                        pred_vec = patient_pred["predicted_value"].values
                    else:
                        # No time column: assume rows align one-to-one with the
                        # observations (previous behaviour).
                        pred_times = sorted_times
                        pred_vec = patient_pred["predicted_value"].values[
                            sorted_indices
                        ]
                    ax.plot(
                        pred_times,
                        pred_vec,
                        "-",
                        color=color,
                        linewidth=2,
                        alpha=0.5,
                    )

            title = f"{output_name} in {protocol_arm}"  # More descriptive title
            ax.set_title(title)

    if not smoke_test:
        plt.tight_layout()
        plt.show()
|