vpop-calibration 2.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vpop_calibration/__init__.py +22 -0
- vpop_calibration/data_generation.py +186 -0
- vpop_calibration/diagnostics.py +162 -0
- vpop_calibration/model/__init__.py +3 -0
- vpop_calibration/model/data.py +420 -0
- vpop_calibration/model/gp.py +517 -0
- vpop_calibration/model/plot.py +243 -0
- vpop_calibration/nlme.py +840 -0
- vpop_calibration/ode.py +203 -0
- vpop_calibration/saem.py +945 -0
- vpop_calibration/structural_model.py +200 -0
- vpop_calibration/test/__init__.py +11 -0
- vpop_calibration/test/test_data.py +21 -0
- vpop_calibration/test/test_gp_flavors.py +89 -0
- vpop_calibration/test/test_gp_saem.py +175 -0
- vpop_calibration/test/test_ode_saem.py +121 -0
- vpop_calibration/utils.py +9 -0
- vpop_calibration/vpop.py +50 -0
- vpop_calibration-2.2.8.dist-info/METADATA +78 -0
- vpop_calibration-2.2.8.dist-info/RECORD +22 -0
- vpop_calibration-2.2.8.dist-info/WHEEL +4 -0
- vpop_calibration-2.2.8.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from ..utils import smoke_test
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def plot_all_solutions(obs_vs_pred: pd.DataFrame) -> None:
    """Plot the overlapped observations and model predictions for all patients, faceted by output and protocol.

    The figure has one row per protocol arm and one column per output. Each patient
    is drawn with one color sampled from the "Spectral" colormap (indexed by the
    patient's position among all ids). Observations are drawn as "+" markers and
    predictions as a solid line, both sorted by time. The figure is only shown when
    `smoke_test` is false.

    Args:
        obs_vs_pred (pd.DataFrame): Full data frame containing observations and predictions from the model. Should contain the following columns
            - `id`
            - `output_name`
            - `protocol_arm`
            - `time`
            - `value`
            - `pred_mean`
    """
    outputs = obs_vs_pred["output_name"].unique().tolist()
    protocol_arms = obs_vs_pred["protocol_arm"].unique().tolist()
    patients = obs_vs_pred["id"].unique().tolist()

    n_cols = len(outputs)
    n_rows = len(protocol_arms)
    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
    )

    # One color per patient, indexed by the patient's position in `patients`,
    # so a given patient has the same color in every subplot.
    cmap = plt.get_cmap("Spectral")
    colors = cmap(np.linspace(0, 1, len(patients)))
    for output_index, output_name in enumerate(outputs):
        for protocol_index, protocol_arm in enumerate(protocol_arms):
            data_to_plot = obs_vs_pred.loc[
                (obs_vs_pred["output_name"] == output_name)
                & (obs_vs_pred["protocol_arm"] == protocol_arm)
            ]
            # Split the facet by patient in a single pass instead of re-scanning it
            # once per patient. Patients absent from this facet get an empty frame,
            # exactly what a boolean filter on `id` would have produced.
            rows_by_patient = {
                patient: rows
                for patient, rows in data_to_plot.groupby("id", sort=False)
            }
            no_rows = data_to_plot.iloc[:0]

            ax = axes[protocol_index, output_index]
            ax.set_xlabel("Time")
            for patient_num, patient_ind in enumerate(patients):
                patient_data = rows_by_patient.get(patient_ind, no_rows)
                time_vec = patient_data["time"].values
                sorted_indices = np.argsort(time_vec)
                sorted_times = time_vec[sorted_indices]
                obs_vec = patient_data["value"].values[sorted_indices]
                pred_vec = patient_data["pred_mean"].values[sorted_indices]
                # Observations: markers only.
                ax.plot(
                    sorted_times,
                    obs_vec,
                    "+",
                    color=colors[patient_num],
                    linewidth=2,
                    alpha=0.6,
                )
                # Model prediction: solid line in the same color.
                ax.plot(
                    sorted_times,
                    pred_vec,
                    "-",
                    color=colors[patient_num],
                    linewidth=2,
                    alpha=0.5,
                )

            title = f"{output_name} in {protocol_arm}"
            ax.set_title(title)
    if not smoke_test:
        plt.tight_layout()
        plt.show()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def plot_individual_solution(obs_vs_pred: pd.DataFrame) -> None:
    """Plot the model prediction (and confidence interval) vs. the input data for a single patient.

    The figure has one row per protocol arm and one column per output. Each subplot
    shows:
    - the observations;
    - the predicted mean;
    - the band between `pred_low` and `pred_high`;
    - a text box listing the patient's parameters.

    The figure is only shown when `smoke_test` is false.

    Args:
        obs_vs_pred (pd.DataFrame): Observations and predictions for exactly one
            patient. Must contain the columns `id`, `output_name`, `protocol_arm`,
            `time`, `value`, `pred_mean`, `pred_low` and `pred_high`. Every other
            column is displayed as a patient parameter.

    Raises:
        ValueError: If `obs_vs_pred` does not contain exactly one patient `id`.
    """
    outputs = obs_vs_pred["output_name"].unique().tolist()
    protocol_arms = obs_vs_pred["protocol_arm"].unique().tolist()
    patients = obs_vs_pred["id"].unique().tolist()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(patients) != 1:
        raise ValueError(
            f"Expected data for exactly one patient, got {len(patients)} distinct ids"
        )
    patient_id = patients[0]
    ncols = len(outputs)
    nrows = len(protocol_arms)
    # The height scales with the number of protocol arms so rows are not squashed.
    _, axes = plt.subplots(
        nrows, ncols, figsize=(9.0 * ncols, 4.0 * nrows), squeeze=False
    )

    # Every column that is not an identifier, the time axis, an observation or a
    # prediction is treated as a patient parameter. `time` must be dropped as well,
    # otherwise it is listed as a parameter and prevents drop_duplicates from
    # collapsing the per-time rows.
    patient_params = obs_vs_pred.drop(
        columns=[
            "id",
            "output_name",
            "protocol_arm",
            "time",
            "value",
            "pred_mean",
            "pred_low",
            "pred_high",
        ]
    ).drop_duplicates()

    # The parameter box is identical for every subplot, so build it once.
    param_lines = ["Parameters:"]
    for param in patient_params:
        # Positional access: the frame keeps obs_vs_pred's index labels, which need
        # not include 0 (e.g. when the patient was sliced out of a larger frame).
        value = patient_params[param].iloc[0]
        # Numbers are shown with 3 decimal places; anything else (e.g. a string
        # covariate) is shown as-is instead of crashing the float format.
        if isinstance(value, (int, float, np.number)):
            formatted = f"{value:.3f}"
        else:
            formatted = str(value)
        param_lines.append(f" {param}: {formatted}")
    param_text = "\n".join(param_lines) + "\n"

    for output_index, output_name in enumerate(outputs):
        for protocol_index, protocol_arm in enumerate(protocol_arms):
            data_to_plot = obs_vs_pred.loc[
                (obs_vs_pred["output_name"] == output_name)
                & (obs_vs_pred["protocol_arm"] == protocol_arm)
            ]
            time_steps = np.array(data_to_plot["time"].values)
            sorted_indices = np.argsort(time_steps)
            sorted_time_steps = time_steps[sorted_indices]
            ax = axes[protocol_index, output_index]
            ax.set_xlabel("Time")
            # Plot observations
            ax.plot(
                sorted_time_steps,
                data_to_plot["value"].values[sorted_indices],
                ".-",
                color="C0",
                linewidth=2,
                alpha=0.6,
                label=output_name,
            )

            # Plot model prediction
            ax.plot(
                sorted_time_steps,
                data_to_plot["pred_mean"].values[sorted_indices],
                "-",
                color="C3",
                linewidth=2,
                alpha=0.5,
                label="GP prediction for " + output_name + " (mean)",
            )
            # Add confidence interval
            ax.fill_between(
                sorted_time_steps,
                data_to_plot["pred_low"].values[sorted_indices],
                data_to_plot["pred_high"].values[sorted_indices],
                alpha=0.5,
                color="C3",
                label="GP prediction for " + output_name + " (CI)",
            )

            ax.legend(loc="upper right")
            title = f"{output_name} in {protocol_arm} for patient {patient_id}"
            ax.set_title(title)

            # Parameter box anchored just right of the axes' top-right corner
            # (axes-relative coordinates).
            ax.text(
                1.02,
                0.98,
                param_text,
                transform=ax.transAxes,
                fontsize=9,
                verticalalignment="top",
                bbox=dict(boxstyle="round,pad=0.5", fc="wheat", alpha=0.5, ec="k"),
            )

    if not smoke_test:
        plt.tight_layout()
        plt.show()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def plot_obs_vs_predicted(
    obs_vs_pred: pd.DataFrame, logScale: Optional[list[bool]] = None
) -> None:
    """Plots the observed vs. predicted values on the training or validation data set, or on a new data set.

    The figure has one row per protocol arm and one column per output. Each subplot:
    - scatters `pred_mean` against `value`, with one marker series per patient;
    - draws the identity line over the range of observed values;
    - shades the band between half and twice that identity line.

    The figure is only shown when `smoke_test` is false.

    Args:
        obs_vs_pred (pd.DataFrame): Must contain the columns `id`, `output_name`,
            `protocol_arm`, `value` and `pred_mean`.
        logScale (Optional[list[bool]]): One flag per output, in order of first
            appearance of `output_name`, selecting log-log axes for that column.
            `None` or an empty list means log scale for every output.

    Raises:
        ValueError: If `logScale` has fewer entries than there are outputs.
    """
    outputs = obs_vs_pred["output_name"].unique().tolist()
    nb_outputs = len(outputs)
    protocol_arms = obs_vs_pred["protocol_arm"].unique().tolist()
    patients = obs_vs_pred["id"].unique().tolist()

    # Validate before creating the figure so a bad call leaves no orphan figure.
    if not logScale:
        logScale = [True] * nb_outputs
    elif len(logScale) < nb_outputs:
        raise ValueError(
            f"logScale has {len(logScale)} entries but there are {nb_outputs} outputs"
        )

    n_cols = nb_outputs
    n_rows = len(protocol_arms)
    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
    )

    for output_index, output_name in enumerate(outputs):
        log_viz = logScale[output_index]
        for protocol_index, protocol_arm in enumerate(protocol_arms):
            ax = axes[protocol_index, output_index]
            ax.set_xlabel("Observed")
            ax.set_ylabel("Predicted")
            data_to_plot = obs_vs_pred.loc[
                (obs_vs_pred["protocol_arm"] == protocol_arm)
                & (obs_vs_pred["output_name"] == output_name)
            ]
            # Split the facet by patient in a single pass instead of re-scanning it
            # once per patient. Patients absent from this facet still get an
            # (empty) plot call, so each axes' color cycle advances once per
            # patient, in the same order as before.
            rows_by_patient = {
                patient: rows
                for patient, rows in data_to_plot.groupby("id", sort=False)
            }
            no_rows = data_to_plot.iloc[:0]
            for ind in patients:
                patient_data = rows_by_patient.get(ind, no_rows)
                ax.plot(
                    patient_data["value"],
                    patient_data["pred_mean"],
                    "o",
                    linewidth=1,
                    alpha=0.6,
                )

            # Series.min()/max() already return scalars; no second reduction needed.
            min_val = data_to_plot["value"].min()
            max_val = data_to_plot["value"].max()
            # Identity line: perfect prediction.
            ax.plot(
                [min_val, max_val],
                [min_val, max_val],
                "-",
                linewidth=1,
                alpha=0.5,
                color="black",
            )
            # Band between half and twice the identity line.
            ax.fill_between(
                [min_val, max_val],
                [min_val / 2, max_val / 2],
                [min_val * 2, max_val * 2],
                linewidth=1,
                alpha=0.25,
                color="black",
            )
            title = f"{output_name} in {protocol_arm}"
            ax.set_title(title)
            if log_viz:
                ax.set_xscale("log")
                ax.set_yscale("log")

    if not smoke_test:
        plt.tight_layout()
        plt.show()
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def plot_loss(iterations: np.ndarray, losses: np.ndarray) -> None:
    """Plot the training loss against the iteration number.

    The curve is drawn on a freshly created figure, so it never lands on the axes of
    a figure left open by an earlier call. Earlier figures can stay open, for
    example, when `smoke_test` skips `plt.show()`. The figure is only shown when
    `smoke_test` is false.

    Args:
        iterations (np.ndarray): x-values of the curve (iteration indices).
        losses (np.ndarray): y-values of the curve (loss at each iteration).
    """
    _, ax = plt.subplots()
    ax.plot(iterations, losses)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss over Iterations")

    if not smoke_test:
        plt.show()
|