vpop-calibration 2.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,243 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from typing import Optional
5
+
6
+ from ..utils import smoke_test
7
+
8
+
9
def plot_all_solutions(obs_vs_pred: pd.DataFrame) -> None:
    """Overlay observations and model predictions for all patients, faceted by output and protocol.

    One subplot is drawn per (protocol arm, output) pair: protocol arms run
    along the rows and outputs along the columns. Inside a subplot every
    patient has its own colour; observations are drawn as ``+`` markers and
    predictions as a solid line, both ordered by increasing time. The figure
    is only laid out and shown when ``smoke_test`` is falsy.

    Args:
        obs_vs_pred (pd.DataFrame): Full data frame containing observations and predictions from the model. Should contain the following columns
            - `id`
            - `output_name`
            - `protocol_arm`
            - `time`
            - `value`
            - `pred_mean`
    """
    output_names = obs_vs_pred["output_name"].unique().tolist()
    arm_names = obs_vs_pred["protocol_arm"].unique().tolist()
    patient_ids = obs_vs_pred["id"].unique().tolist()

    n_rows, n_cols = len(arm_names), len(output_names)
    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
    )

    # One colour per patient, evenly spaced over the "Spectral" colormap.
    palette = plt.get_cmap("Spectral")(np.linspace(0, 1, len(patient_ids)))

    for col, output_name in enumerate(output_names):
        is_output = obs_vs_pred["output_name"] == output_name
        for row, protocol_arm in enumerate(arm_names):
            is_arm = obs_vs_pred["protocol_arm"] == protocol_arm
            facet = obs_vs_pred.loc[is_output & is_arm]

            ax = axes[row, col]
            ax.set_xlabel("Time")

            for patient_id, colour in zip(patient_ids, palette):
                patient_rows = facet.loc[facet["id"] == patient_id]
                times = patient_rows["time"].values
                # Order every series by time so the prediction line is drawn
                # from left to right.
                order = np.argsort(times)
                times_sorted = times[order]
                observed = patient_rows["value"].values[order]
                predicted = patient_rows["pred_mean"].values[order]

                shared = {"color": colour, "linewidth": 2}
                ax.plot(times_sorted, observed, "+", alpha=0.6, **shared)
                ax.plot(times_sorted, predicted, "-", alpha=0.5, **shared)

            ax.set_title(f"{output_name} in {protocol_arm}")

    if smoke_test:
        return
    plt.tight_layout()
    plt.show()
72
+
73
+
74
def plot_individual_solution(obs_vs_pred: pd.DataFrame) -> None:
    """Plot the model prediction (and confidence interval) vs. the input data for a single patient.

    One subplot is drawn per (protocol arm, output) pair: protocol arms run
    along the rows and outputs along the columns. Each subplot shows the
    observations, the predicted mean, and a band between ``pred_low`` and
    ``pred_high``, all ordered by increasing time. A text box listing the
    remaining columns of the frame is attached to the right of every subplot.
    The figure is only laid out and shown when ``smoke_test`` is falsy.

    Args:
        obs_vs_pred (pd.DataFrame): Data for exactly one patient. The columns
            `id`, `output_name`, `protocol_arm`, `time`, `value`,
            `pred_mean`, `pred_low` and `pred_high` are read; every column
            other than `id`, `output_name`, `protocol_arm`, `value`,
            `pred_mean`, `pred_low` and `pred_high` is reported in the text
            box and must be formattable as a float.

    Raises:
        AssertionError: If `obs_vs_pred` does not contain exactly one `id`.
    """
    outputs = obs_vs_pred["output_name"].unique().tolist()
    protocol_arms = obs_vs_pred["protocol_arm"].unique().tolist()
    patients = obs_vs_pred["id"].unique().tolist()
    assert len(patients) == 1, (
        f"plot_individual_solution expects data for exactly one patient, "
        f"got {len(patients)}"
    )
    patient_id = patients[0]

    ncols = len(outputs)
    nrows = len(protocol_arms)
    # Scale the height with the number of rows so that several protocol arms
    # are not squeezed into a single-row figure.
    _, axes = plt.subplots(
        nrows, ncols, figsize=(9.0 * ncols, 4.0 * nrows), squeeze=False
    )

    patient_params = obs_vs_pred.drop(
        columns=[
            "id",
            "output_name",
            "protocol_arm",
            "value",
            "pred_mean",
            "pred_low",
            "pred_high",
        ]
    ).drop_duplicates()
    # NOTE(review): `time` is not in the dropped list, so it is listed in the
    # box alongside the other remaining columns -- confirm this is intended.

    # The text is identical for every subplot, so build it once. `.iloc[0]`
    # reads the first remaining row by position; a label lookup (`[0]`) fails
    # when the frame is a slice whose index does not contain the label 0.
    param_text = "Parameters:\n" + "".join(
        f" {param}: {patient_params[param].iloc[0]:.3f}\n"  # 3 decimal places
        for param in patient_params
    )

    for output_index, output_name in enumerate(outputs):
        for protocol_index, protocol_arm in enumerate(protocol_arms):
            data_to_plot = obs_vs_pred.loc[
                (obs_vs_pred["output_name"] == output_name)
                & (obs_vs_pred["protocol_arm"] == protocol_arm)
            ]
            time_steps = np.array(data_to_plot["time"].values)
            # Order every series by time so lines and the band are drawn
            # from left to right.
            sorted_indices = np.argsort(time_steps)
            sorted_time_steps = time_steps[sorted_indices]
            ax = axes[protocol_index, output_index]
            ax.set_xlabel("Time")

            # Plot observations
            ax.plot(
                sorted_time_steps,
                data_to_plot["value"].values[sorted_indices],
                ".-",
                color="C0",
                linewidth=2,
                alpha=0.6,
                label=output_name,
            )

            # Plot model prediction
            ax.plot(
                sorted_time_steps,
                data_to_plot["pred_mean"].values[sorted_indices],
                "-",
                color="C3",
                linewidth=2,
                alpha=0.5,
                label="GP prediction for " + output_name + " (mean)",
            )
            # Add confidence interval
            ax.fill_between(
                sorted_time_steps,
                data_to_plot["pred_low"].values[sorted_indices],
                data_to_plot["pred_high"].values[sorted_indices],
                alpha=0.5,
                color="C3",
                label="GP prediction for " + output_name + " (CI)",
            )

            ax.legend(loc="upper right")
            title = f"{output_name} in {protocol_arm} for patient {patient_id}"
            ax.set_title(title)

            ax.text(
                1.02,
                0.98,
                param_text,
                transform=ax.transAxes,  # Coordinate system is relative to the axis
                fontsize=9,
                verticalalignment="top",
                bbox=dict(boxstyle="round,pad=0.5", fc="wheat", alpha=0.5, ec="k"),
            )

    if not smoke_test:
        plt.tight_layout()
        plt.show()
162
+
163
+
164
def plot_obs_vs_predicted(
    obs_vs_pred: pd.DataFrame, logScale: Optional[list[bool]] = None
) -> None:
    """Plot the observed vs. predicted values on the training or validation data set, or on a new data set.

    One subplot is drawn per (protocol arm, output) pair: protocol arms run
    along the rows and outputs along the columns. Each patient is drawn as
    its own series of ``o`` markers. A black identity line spans the observed
    range, surrounded by a shaded band that goes from half to twice the
    observed value. The figure is only laid out and shown when ``smoke_test``
    is falsy.

    Args:
        obs_vs_pred (pd.DataFrame): Data frame containing observations and
            predictions. The columns `id`, `output_name`, `protocol_arm`,
            `value` and `pred_mean` are read.
        logScale (Optional[list[bool]]): One flag per output, in the order
            the outputs first appear in `output_name`; a true flag puts both
            axes of that output's subplots on a log scale. When `None` or
            empty, log axes are used for every output.

    Raises:
        IndexError: If `logScale` has fewer entries than there are outputs.
    """
    outputs = obs_vs_pred["output_name"].unique().tolist()
    nb_outputs = len(outputs)
    protocol_arms = obs_vs_pred["protocol_arm"].unique().tolist()
    nb_protocol_arms = len(protocol_arms)
    patients = obs_vs_pred["id"].unique().tolist()

    if not logScale:
        logScale = [True] * nb_outputs
    elif len(logScale) < nb_outputs:
        # Fail before any figure is created instead of part-way through the
        # plotting loop.
        raise IndexError(
            f"logScale has {len(logScale)} entries but the data contains "
            f"{nb_outputs} outputs"
        )

    n_cols = nb_outputs
    n_rows = nb_protocol_arms
    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
    )

    for output_index, output_name in enumerate(outputs):
        # The scale choice depends on the output only, not on the protocol.
        log_viz = logScale[output_index]
        for protocol_index, protocol_arm in enumerate(protocol_arms):
            ax = axes[protocol_index, output_index]
            ax.set_xlabel("Observed")
            ax.set_ylabel("Predicted")
            data_to_plot = obs_vs_pred.loc[
                (obs_vs_pred["protocol_arm"] == protocol_arm)
                & (obs_vs_pred["output_name"] == output_name)
            ]
            # One plot call per patient so that each patient takes the next
            # colour of the axes' colour cycle.
            for ind in patients:
                patient_data = data_to_plot.loc[data_to_plot["id"] == ind]
                ax.plot(
                    patient_data["value"],
                    patient_data["pred_mean"],
                    "o",
                    linewidth=1,
                    alpha=0.6,
                )

            # Range of the observations in this subplot; a single reduction
            # over the column already yields a scalar.
            min_val = data_to_plot["value"].min()
            max_val = data_to_plot["value"].max()
            # Identity line: predicted == observed.
            ax.plot(
                [min_val, max_val],
                [min_val, max_val],
                "-",
                linewidth=1,
                alpha=0.5,
                color="black",
            )
            # Band between half and twice the observed value.
            ax.fill_between(
                [min_val, max_val],
                [min_val / 2, max_val / 2],
                [min_val * 2, max_val * 2],
                linewidth=1,
                alpha=0.25,
                color="black",
            )
            title = f"{output_name} in {protocol_arm}"
            ax.set_title(title)
            if log_viz:
                ax.set_xscale("log")
                ax.set_yscale("log")

    if not smoke_test:
        plt.tight_layout()
        plt.show()
233
+
234
+
235
def plot_loss(iterations: np.ndarray, losses: np.ndarray) -> None:
    """Draw the loss as a function of the iteration number.

    The curve, axis labels and title are added to the current matplotlib
    axes. The figure is only shown when ``smoke_test`` is falsy.

    Args:
        iterations (np.ndarray): Values plotted on the x axis.
        losses (np.ndarray): Values plotted on the y axis.
    """
    ax = plt.gca()
    ax.plot(iterations, losses)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss over Iterations")

    if smoke_test:
        return
    plt.show()