guts-base 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +14 -0
- guts_base/data/__init__.py +34 -0
- guts_base/data/expydb.py +247 -0
- guts_base/data/generator.py +96 -0
- guts_base/data/openguts.py +294 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +137 -0
- guts_base/data/time_of_death.py +571 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +251 -0
- guts_base/plot.py +162 -0
- guts_base/prob.py +412 -0
- guts_base/sim/__init__.py +14 -0
- guts_base/sim/base.py +464 -0
- guts_base/sim/ecx.py +357 -0
- guts_base/sim/mempy.py +252 -0
- guts_base/sim/report.py +72 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.2.dist-info/METADATA +836 -0
- guts_base-0.8.2.dist-info/RECORD +24 -0
- guts_base-0.8.2.dist-info/WHEEL +5 -0
- guts_base-0.8.2.dist-info/entry_points.txt +3 -0
- guts_base-0.8.2.dist-info/licenses/LICENSE +674 -0
- guts_base-0.8.2.dist-info/top_level.txt +1 -0
guts_base/sim/ecx.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from functools import partial
|
|
3
|
+
import numpy as np
|
|
4
|
+
import xarray as xr
|
|
5
|
+
from typing import Literal, Optional, Dict, List
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from scipy.optimize import minimize
|
|
8
|
+
from matplotlib import pyplot as plt
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from pymob import SimulationBase
|
|
12
|
+
|
|
13
|
+
class ECxEstimator:
|
|
14
|
+
"""Estimates the exposure level that corresponds to a given effect. The algorithm
|
|
15
|
+
operates by varying a given exposure profile (x_in)
|
|
16
|
+
"""
|
|
17
|
+
_name = "EC"
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
sim: SimulationBase,
|
|
22
|
+
effect: str,
|
|
23
|
+
x: float=0.5,
|
|
24
|
+
id: Optional[str]=None,
|
|
25
|
+
time: Optional[float]=None,
|
|
26
|
+
x_in: Optional[xr.Dataset]=None,
|
|
27
|
+
):
|
|
28
|
+
self.sim = sim.copy()
|
|
29
|
+
self.time = time
|
|
30
|
+
self.x = x
|
|
31
|
+
self.id = id
|
|
32
|
+
self.effect = effect
|
|
33
|
+
self._mode = None
|
|
34
|
+
|
|
35
|
+
if id is None:
|
|
36
|
+
self.sim.coordinates["id"] = [self.sim.coordinates["id"][0]]
|
|
37
|
+
else:
|
|
38
|
+
self.sim.coordinates["id"] = [id]
|
|
39
|
+
|
|
40
|
+
self.sim.model_parameters["x_in"] = x_in
|
|
41
|
+
|
|
42
|
+
# self.sim.observations = self.sim.expand_batch_like_coordinate_to_new_dimension(
|
|
43
|
+
# coordinate="exposure_path",
|
|
44
|
+
# variables=["Flupyradifurone"]
|
|
45
|
+
# )
|
|
46
|
+
|
|
47
|
+
# self.sim.config.data_structure.remove("Flupyradifurone")
|
|
48
|
+
|
|
49
|
+
# # TODO: COnstruct a sim if the input dims change
|
|
50
|
+
# self.sim.config.data_structure.exposure.dimensions = ["id", "time", "exposure_path"]
|
|
51
|
+
self.sim.config.data_structure.survival.observed = False
|
|
52
|
+
self.sim.observations = self.sim.observations.sel(id=self.sim.coordinates["id"])
|
|
53
|
+
|
|
54
|
+
self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims="time")
|
|
55
|
+
self.sim.dispatch_constructor()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _evaluate(self, factor, theta):
|
|
60
|
+
evaluator = self.sim.dispatch(
|
|
61
|
+
theta=theta,
|
|
62
|
+
x_in=self.sim.validate_model_input(self.sim.model_parameters["x_in"] * factor)
|
|
63
|
+
)
|
|
64
|
+
evaluator()
|
|
65
|
+
return evaluator
|
|
66
|
+
|
|
67
|
+
def _loss(self, log_factor, theta):
|
|
68
|
+
# exponentiate the log factor
|
|
69
|
+
factor = np.exp(log_factor)
|
|
70
|
+
|
|
71
|
+
e = self._evaluate(factor, theta)
|
|
72
|
+
s = e.results.sel(time=self.time)[self.effect].values
|
|
73
|
+
|
|
74
|
+
return (s - (1 - self.x)) ** 2
|
|
75
|
+
|
|
76
|
+
def _posterior_mean(self):
|
|
77
|
+
mean = self.sim.inferer.idata.posterior.mean(("chain", "draw"))
|
|
78
|
+
mean = {k: v["data"] for k, v in mean.to_dict()["data_vars"].items()}
|
|
79
|
+
return mean
|
|
80
|
+
|
|
81
|
+
def _posterior_sample(self, i):
|
|
82
|
+
posterior_stacked = self.sim.inferer.idata.posterior.stack(
|
|
83
|
+
sample=("chain", "draw")
|
|
84
|
+
)
|
|
85
|
+
sample = posterior_stacked.isel(sample=i)
|
|
86
|
+
sample = {k: v["data"] for k, v in sample.to_dict()["data_vars"].items()}
|
|
87
|
+
return sample
|
|
88
|
+
|
|
89
|
+
def plot_loss_curve(self):
|
|
90
|
+
posterior_mean = self._posterior_mean()
|
|
91
|
+
|
|
92
|
+
factor = np.linspace(-2,2, 100)
|
|
93
|
+
y = list(map(partial(self._loss, theta=posterior_mean), factor))
|
|
94
|
+
|
|
95
|
+
fig, ax = plt.subplots(1,1, sharey=True, figsize=(4, 3))
|
|
96
|
+
ax.plot(
|
|
97
|
+
np.exp(factor), y,
|
|
98
|
+
color="black",
|
|
99
|
+
label=f"$\ell = S(t={self.time},x_{{in}}=C_{{ext}} \cdot \phi) - {self.x}$"
|
|
100
|
+
)
|
|
101
|
+
ax.set_ylabel("Loss ($\ell$)")
|
|
102
|
+
ax.set_xlabel("Multiplication factor ($\phi$)")
|
|
103
|
+
ax.set_title(f"ID: {self.sim.coordinates['id'][0]}")
|
|
104
|
+
ax.set_ylim(0, np.max(y) * 1.25)
|
|
105
|
+
ax.legend(frameon=False)
|
|
106
|
+
fig.tight_layout()
|
|
107
|
+
|
|
108
|
+
def estimate(
|
|
109
|
+
self,
|
|
110
|
+
mode: Literal["draws", "mean", "manual"] = "draws",
|
|
111
|
+
draws: Optional[int] = None,
|
|
112
|
+
parameters: Optional[Dict[str,float|List[float]]] = None,
|
|
113
|
+
log_x0: float = 0.0,
|
|
114
|
+
accept_tol: float = 1e-5,
|
|
115
|
+
optimizer_tol: float = 1e-5,
|
|
116
|
+
method: str = "cobyla",
|
|
117
|
+
**optimizer_kwargs
|
|
118
|
+
):
|
|
119
|
+
"""The minimizer for the EC_x operates on the unbounded linear scale, estimating
|
|
120
|
+
the log-modification factor. Converted to the linear scale by factor=exp(x), the
|
|
121
|
+
profile modification factor is obtained.
|
|
122
|
+
|
|
123
|
+
Using x0=0.0 means optimization will start on the linear scale at the unmodified
|
|
124
|
+
exposure profile. Using the log scale for optimization will provide much smoother
|
|
125
|
+
optimization performance because multiplicative steps on the log scale require
|
|
126
|
+
much less adaptation.
|
|
127
|
+
|
|
128
|
+
Parameters
|
|
129
|
+
----------
|
|
130
|
+
|
|
131
|
+
mode : Literal['draws', 'mean', 'manual']
|
|
132
|
+
mode of estimation. mode='mean' takes the mean of the posterior and estimate
|
|
133
|
+
the ECx for this singular value. mode='draws' takes samples from the posterior
|
|
134
|
+
and estimate the ECx for each of the parameter draws. mode='manual' takes
|
|
135
|
+
a parameter set (Dict) in the parameters argument and uses that for estimation.
|
|
136
|
+
Default: 'draws'
|
|
137
|
+
|
|
138
|
+
draws : int
|
|
139
|
+
Number of draws to take from the posterior. Only takes effect if mode='draw'.
|
|
140
|
+
Raises an exception if draws < 100, because this is insufficient for a
|
|
141
|
+
reasonable uncertainty estimate. Default: None (using all samples from the
|
|
142
|
+
posterior)
|
|
143
|
+
|
|
144
|
+
parameters : Dict[str,float|list[float]]
|
|
145
|
+
a parameter dictionary passed used as model parameters for finding the ECx
|
|
146
|
+
value. Default: None
|
|
147
|
+
|
|
148
|
+
log_x0 : float
|
|
149
|
+
the starting value for the multiplication factor of the exposure profile for
|
|
150
|
+
the minimization algorithm. This value is on the log scale. This means,
|
|
151
|
+
exp(log_x0=0.0) = 1.0, which means that the log_x0=0.0 will start at an
|
|
152
|
+
unmodified exposure profile. Default: 0.0
|
|
153
|
+
|
|
154
|
+
accept_tol : float
|
|
155
|
+
After optimization is finished, accept_tol is used to assess if the loss
|
|
156
|
+
function for the individual draws exceed a tolerance. These results are
|
|
157
|
+
discarded and a warning is emitted. This is to assert that no faulty optimization
|
|
158
|
+
results enter the estimate. Default: 1e-5
|
|
159
|
+
|
|
160
|
+
optimizer_tol : float
|
|
161
|
+
Tolerance limit for the minimzer to stop optimization. Default 1e-5
|
|
162
|
+
|
|
163
|
+
method : str
|
|
164
|
+
Minization algorithm. See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
|
|
165
|
+
Default: 'cobyla'
|
|
166
|
+
|
|
167
|
+
optimizer_kwargs :
|
|
168
|
+
Additional arguments to pass to the optimizer
|
|
169
|
+
|
|
170
|
+
"""
|
|
171
|
+
x0_tries = np.array([0.0, -1.0, 1.0, -2.0, 2.0]) + log_x0
|
|
172
|
+
|
|
173
|
+
if mode == "draws":
|
|
174
|
+
if draws is None:
|
|
175
|
+
draws = (
|
|
176
|
+
self.sim.inferer.idata.posterior.sizes["chain"] *
|
|
177
|
+
self.sim.inferer.idata.posterior.sizes["draw"]
|
|
178
|
+
)
|
|
179
|
+
elif draws < 100:
|
|
180
|
+
raise ValueError(
|
|
181
|
+
"draws must be larger than 100. Preferably > 1000. "
|
|
182
|
+
f"If you don't want uncertainty assessment of the {self._name} "
|
|
183
|
+
"estimates, use mode='mean'"
|
|
184
|
+
)
|
|
185
|
+
else:
|
|
186
|
+
pass
|
|
187
|
+
|
|
188
|
+
elif mode == "mean":
|
|
189
|
+
draws = 1
|
|
190
|
+
elif mode == "manual":
|
|
191
|
+
draws = 1
|
|
192
|
+
if parameters is None:
|
|
193
|
+
raise ValueError(
|
|
194
|
+
"parameters need to be provided if mode='manual'"
|
|
195
|
+
)
|
|
196
|
+
else:
|
|
197
|
+
raise NotImplementedError(
|
|
198
|
+
f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
self._mode = mode
|
|
202
|
+
mult_factor = []
|
|
203
|
+
loss = []
|
|
204
|
+
iterations = []
|
|
205
|
+
for i in tqdm(range(draws)):
|
|
206
|
+
if mode == "draws":
|
|
207
|
+
sample = self._posterior_sample(i)
|
|
208
|
+
elif mode == "mean":
|
|
209
|
+
sample = self._posterior_mean()
|
|
210
|
+
elif mode == "manual":
|
|
211
|
+
sample = parameters
|
|
212
|
+
else:
|
|
213
|
+
raise NotImplementedError(
|
|
214
|
+
f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
success = False
|
|
218
|
+
iteration = 0
|
|
219
|
+
while not success and iteration < len(x0_tries):
|
|
220
|
+
opt_res = minimize(
|
|
221
|
+
self._loss, x0=x0_tries[iteration],
|
|
222
|
+
method=method,
|
|
223
|
+
tol=optimizer_tol,
|
|
224
|
+
args=(sample,),
|
|
225
|
+
**optimizer_kwargs
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
success = opt_res.fun < accept_tol
|
|
229
|
+
|
|
230
|
+
# convert to linear scale from log scale
|
|
231
|
+
factor = np.exp(opt_res.x)
|
|
232
|
+
|
|
233
|
+
mult_factor.extend(factor)
|
|
234
|
+
iterations.append(iteration)
|
|
235
|
+
loss.append(opt_res.fun)
|
|
236
|
+
|
|
237
|
+
res_full = pd.DataFrame(dict(factor = mult_factor, loss=loss, retries=iterations))
|
|
238
|
+
if sum(res_full.loss >= accept_tol) > 0:
|
|
239
|
+
warnings.warn(
|
|
240
|
+
f"Not all optimizations converged on the {self._name}_{self.x}. " +
|
|
241
|
+
"Adjust starting values and method")
|
|
242
|
+
print(res_full)
|
|
243
|
+
|
|
244
|
+
res = res_full.loc[res_full.loss < accept_tol,:]
|
|
245
|
+
|
|
246
|
+
summary = {
|
|
247
|
+
"mean": np.round(np.mean(res.factor.values), 4),
|
|
248
|
+
"q05": np.round(np.quantile(res.factor.values, 0.05), 4),
|
|
249
|
+
"q95": np.round(np.quantile(res.factor.values, 0.95), 4),
|
|
250
|
+
"std": np.round(np.std(res.factor.values), 4),
|
|
251
|
+
"cv": np.round(np.std(res.factor.values)/np.mean(res.factor.values), 2),
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
self.results = pd.Series(summary)
|
|
255
|
+
self.results_full = res_full
|
|
256
|
+
|
|
257
|
+
print("{name}_{x}".format(name=self._name, x=int(self.x * 100),))
|
|
258
|
+
print(self.results)
|
|
259
|
+
print("\n")
|
|
260
|
+
|
|
261
|
+
def plot_profile_and_effect(
|
|
262
|
+
self,
|
|
263
|
+
parameters: Optional[Dict[str,float|List[float]]] = None
|
|
264
|
+
):
|
|
265
|
+
coordinates_backup = self.sim.coordinates["time"].copy()
|
|
266
|
+
|
|
267
|
+
self.sim.coordinates["time"] = np.linspace(0, self.time, 100)
|
|
268
|
+
self.sim.dispatch_constructor()
|
|
269
|
+
|
|
270
|
+
if self._mode is None:
|
|
271
|
+
raise RuntimeError(
|
|
272
|
+
"Run .estimate() before plot_profile_and_effect()"
|
|
273
|
+
)
|
|
274
|
+
elif self._mode == "mean" or self._mode == "draws":
|
|
275
|
+
e_new = self._evaluate(factor=self.results["mean"], theta=self._posterior_mean())
|
|
276
|
+
e_old = self._evaluate(factor=1.0, theta=self._posterior_mean())
|
|
277
|
+
elif self._mode == "manual":
|
|
278
|
+
if parameters is None:
|
|
279
|
+
raise RuntimeError(
|
|
280
|
+
f"If {self._name}_x was estimated using manual mode, parameters must "+
|
|
281
|
+
"also be provided here."
|
|
282
|
+
)
|
|
283
|
+
e_new = self._evaluate(factor=self.results["mean"], theta=parameters)
|
|
284
|
+
e_old = self._evaluate(factor=1.0, theta=parameters)
|
|
285
|
+
|
|
286
|
+
extra_dim = [k for k in list(e_old.results.coords.keys()) if k not in ["time", "id"]]
|
|
287
|
+
|
|
288
|
+
if len(extra_dim) > 0:
|
|
289
|
+
labels_old = [
|
|
290
|
+
f"{l} (original)" for l
|
|
291
|
+
in e_old.results.coords[extra_dim[0]].values
|
|
292
|
+
]
|
|
293
|
+
labels_new = [
|
|
294
|
+
f"{l} (modified)" for l
|
|
295
|
+
in e_new.results.coords[extra_dim[0]].values
|
|
296
|
+
]
|
|
297
|
+
else:
|
|
298
|
+
labels_old = "original"
|
|
299
|
+
labels_new = "modified"
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
fig, (ax1, ax2) = plt.subplots(2,1, height_ratios=[1,3], sharex=True)
|
|
304
|
+
ax1.plot(
|
|
305
|
+
e_old.results.time, e_old.results.exposure.isel(id=0),
|
|
306
|
+
ls="--", label=labels_old,
|
|
307
|
+
)
|
|
308
|
+
ax1.set_prop_cycle(None)
|
|
309
|
+
ax1.plot(
|
|
310
|
+
e_new.results.time, e_new.results.exposure.isel(id=0),
|
|
311
|
+
label=labels_new
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
ax2.plot(
|
|
316
|
+
e_new.results.time, e_new.results.survival.isel(id=0),
|
|
317
|
+
color="black", ls="--", label="modified"
|
|
318
|
+
)
|
|
319
|
+
ax1.set_prop_cycle(None)
|
|
320
|
+
|
|
321
|
+
ax2.plot(
|
|
322
|
+
e_old.results.time, e_old.results.survival.isel(id=0),
|
|
323
|
+
color="black", ls="-", label="original"
|
|
324
|
+
)
|
|
325
|
+
ax2.hlines(self.x, e_new.results.time[0], self.time, color="grey")
|
|
326
|
+
ax1.set_ylabel("Exposure")
|
|
327
|
+
ax2.set_ylabel("Survival")
|
|
328
|
+
ax2.set_xlabel("Time")
|
|
329
|
+
ax1.legend()
|
|
330
|
+
ax2.legend()
|
|
331
|
+
ax2.set_xlim(0, None)
|
|
332
|
+
ax1.set_ylim(0, None)
|
|
333
|
+
ax2.set_ylim(0, None)
|
|
334
|
+
fig.tight_layout()
|
|
335
|
+
|
|
336
|
+
self.sim.coordinates["time"] = coordinates_backup
|
|
337
|
+
self.sim.dispatch_constructor()
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class LPxEstimator(ECxEstimator):
    """LPx estimator for the existing exposure profile of one ID.

    Takes the exposure profile of ``id`` from the simulation and estimates the
    multiplication factor of that profile which results in an effect of x on
    survival at the last simulated time point.
    """
    _name = "LP"

    def __init__(
        self,
        sim: SimulationBase,
        id: str,
        x: float=0.5
    ):
        profile = sim.model_parameters["x_in"].sel(id=[id])
        end_time = sim.coordinates["time"][-1]
        super().__init__(
            sim=sim,
            effect="survival",
            x=x,
            id=id,
            time=end_time,
            x_in=profile,
        )
|
guts_base/sim/mempy.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
from typing import Dict, Optional, Literal
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import xarray as xr
|
|
8
|
+
from pymob import SimulationBase
|
|
9
|
+
from pymob.sim.config import Config, DataVariable, Datastructure
|
|
10
|
+
from pymob.sim.parameters import Param
|
|
11
|
+
from guts_base.sim import GutsBase
|
|
12
|
+
from mempy.model import (
|
|
13
|
+
Model,
|
|
14
|
+
RED_IT,
|
|
15
|
+
RED_SD,
|
|
16
|
+
RED_IT_DA,
|
|
17
|
+
RED_SD_DA,
|
|
18
|
+
RED_IT_IA,
|
|
19
|
+
RED_SD_IA,
|
|
20
|
+
BufferGUTS_IT,
|
|
21
|
+
BufferGUTS_IT_CA,
|
|
22
|
+
BufferGUTS_IT_DA
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# public API of this module
__all__ = ["PymobSimulator"]
|
|
28
|
+
|
|
29
|
+
class PymobSimulator(GutsBase):
    """GutsBase simulation constructed directly from a mempy model.

    Use :meth:`from_mempy` to convert mempy exposure/survival data and a mempy
    ``Model`` instance into a configured and initialized pymob simulation.
    """

    @classmethod
    def from_mempy(
        cls,
        exposure_data: Dict,
        survival_data: pd.DataFrame,
        model: Model,
        info_dict: Optional[Dict] = None,
        pymob_config: Optional[Config] = None,
        output_directory: str|pathlib.Path = pathlib.Path("output/pymob"),
        default_prior: Literal["uniform", "lognorm"] = "lognorm",
    ) -> SimulationBase:
        """Construct a PymobSimulator from mempy data and a mempy model.

        Parameters
        ----------
        exposure_data : Dict[str, pd.DataFrame]
            Mapping of exposure name -> DataFrame whose index is time and whose
            columns are ids. The keys become coordinates of ``model.extra_dim``.
        survival_data : pd.DataFrame
            Survival observations; index is time, columns are ids.
        model : mempy.model.Model
            Model whose state variables, parameters, right-hand side and
            likelihood are transferred to the simulation.
        info_dict : Dict, optional
            Currently unused.
        pymob_config : Config, optional
            Configuration to start from. If None, a fresh ``Config`` is set up
            with :meth:`configure`.
        output_directory : str | pathlib.Path
            Output directory written to ``config.case_study.output``.
        default_prior : 'uniform' | 'lognorm'
            Prior used for parameters that do not specify one.

        Returns
        -------
        The validated simulation with its constructor dispatched.
        """

        if pymob_config is None:
            cfg = Config()
            # Configure: The configuration can be overridden in a subclass to override the
            # configuration
            cls.configure(config=cfg)
        else:
            cfg = pymob_config

        if isinstance(output_directory, str):
            output_directory = pathlib.Path(output_directory)

        cfg.case_study.output = str(output_directory)

        # parse observations
        # obs can be simply subset by selection obs.sel(substance="Exposure-Dime")
        observations = xr.combine_by_coords([
            cls._exposure_data_to_xarray(exposure_data, dim=model.extra_dim),
            cls._survival_data_to_xarray(survival_data)
        ])

        # configure model and likelihood function
        cfg.simulation.model = type(model).__name__
        cfg.inference_numpyro.user_defined_error_model = str(model._likelihood_func_jax.__name__)

        # derive data structure and params from the model instance
        cls._set_data_structure(config=cfg, model=model)
        cls._set_params(config=cfg, model=model, default_prior=default_prior)

        # configure starting values and input
        cfg.simulation.x_in = ["exposure=exposure"]
        cfg.simulation.y0 = [f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v]

        # create a simulation object
        sim = cls(config=cfg)
        sim.config.create_directory(directory="results", force=True)

        # initialize
        sim.load_modules()
        sim.set_logger()

        sim.initialize(input={"observations": observations, "model": model})

        sim.validate()
        sim.dispatch_constructor()

        return sim

    def initialize(self, input=None):
        """Attach the mempy model's JAX right-hand side and solver post-processing.

        ``input`` must be a dict containing the key 'model' (a mempy Model).
        It is then forwarded unchanged to ``GutsBase.initialize``.
        """
        if input is None or "model" not in input:
            raise ValueError(
                "PymobSimulator.initialize requires input={'model': <mempy Model>, ...}"
            )
        self.model = input["model"]._rhs_jax
        self.solver_post_processing = input["model"]._solver_post_processing

        super().initialize(input=input)

    @classmethod
    def configure(cls, config: Config):
        """This is normally set in the configuration file passed to a SimulationBase class.
        Since the mempy to pymob converter initializes pymob.SimulationBase from scratch
        (without using a config file), the necessary settings have to be specified here.
        """
        config.case_study.output = "results"

        config.simulation.x_dimension = "time"
        config.simulation.batch_dimension = "id"
        config.simulation.solver_post_processing = None
        config.simulation.unit_time = "day"
        config.simulation.n_reindexed_x = 100
        config.simulation.forward_interpolate_exposure_data = True

        config.inference.extra_vars = ["eps", "survivors_before_t"]
        config.inference.n_predictions = 100

        config.jaxsolver.diffrax_solver = "Tsit5"
        config.jaxsolver.rtol = 1e-10
        config.jaxsolver.atol = 1e-12
        config.jaxsolver.throw_exception = True
        config.jaxsolver.pcoeff = 0.3
        config.jaxsolver.icoeff = 0.3
        config.jaxsolver.dcoeff = 0.0
        config.jaxsolver.max_steps = 1000000

        config.inference_numpyro.gaussian_base_distribution = True
        config.inference_numpyro.kernel = "svi"
        config.inference_numpyro.init_strategy = "init_to_median"
        config.inference_numpyro.svi_iterations = 10_000
        config.inference_numpyro.svi_learning_rate = 0.001

    @staticmethod
    def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
        """Combine per-exposure DataFrames into one 'exposure' variable.

        Each DataFrame's index is treated as 'time' and its columns as 'id'. The
        dict keys become coordinates of ``dim``. The input DataFrames are not
        modified.

        TODO: Currently no rect interpolation
        """
        arrays = {
            # rename_axis returns a copy; the name override makes all dimensions line up
            key: df.rename_axis("time").to_xarray().to_dataarray(dim="id", name=key)
            for key, df in exposure_data.items()
        }

        exposure_array = xr.Dataset(arrays).to_dataarray(dim=dim, name="exposure")
        exposure_array = exposure_array.transpose("id", "time", ...)
        return xr.Dataset({"exposure": exposure_array})

    @staticmethod
    def _survival_data_to_xarray(survival_data: pd.DataFrame):
        """Convert a survival DataFrame (index: time, columns: id) to a Dataset.

        The input DataFrame is not modified.
        """
        # TODO: survival name is currently not kept because the raw data is not transferred from the survival
        survival_array = (
            survival_data.rename_axis("time")
            .to_xarray()
            .to_dataarray(dim="id", name="survival")
        )
        survival_array = survival_array.transpose("id", "time", ...)
        return xr.Dataset({"survival": survival_array})

    @classmethod
    def _set_data_structure(cls, config: Config, model: Model):
        """Takes a dictionary that is specified in the model and uses only keys that
        are fields of the DataVariable config-model"""

        state_dict = model.state_variables

        config.data_structure = Datastructure(**{
            key: DataVariable(**{
                k: v for k, v in state_info.items()
                if k in DataVariable.model_fields
            })
            for key, state_info in state_dict.items()
        })

    @staticmethod
    def _format_prior_value(value) -> str:
        """Format a scalar or 1-D array for a prior string without whitespace.

        Arrays are written as '[a,b,...]'. Plain ``str(array)`` pads elements with
        spaces ('[ 1. 10.]'), which breaks the whitespace normalization below.
        """
        arr = np.asarray(value, dtype=float)
        if arr.ndim == 0:
            return str(arr.item())
        return np.array2string(arr, separator=",", max_line_width=10**9).replace(" ", "")

    @staticmethod
    def _normalize_prior_string(prior: str) -> str:
        """Replace whitespace in a prior string by commas without creating empty fields.

        Examples: '[1. 2.]' -> '[1.,2.]'; 'lognorm(scale=1, s=2)' -> 'lognorm(scale=1,s=2)'.
        Strings without whitespace are returned unchanged.
        """
        normalized = re.sub(r"\s+", ",", prior.strip())
        normalized = re.sub(r",{2,}", ",", normalized)
        # drop commas directly after an opening bracket/parenthesis or '='
        normalized = re.sub(r"([\[(=]),", r"\1", normalized)
        # drop commas directly before a closing bracket/parenthesis or '='
        normalized = re.sub(r",([\])=])", r"\1", normalized)
        return normalized

    @classmethod
    def _set_params(cls, config: Config, model: Model, default_prior: str):
        """Translate ``model.params_info`` into ``config.model_parameters``.

        Parameters whose names differ only by digits (e.g. 'k_1', 'k_2') are
        grouped into one parameter named without the digits. Missing entries are
        filled from ``model._params_info_defaults``. Parameters without a prior get
        ``default_prior``. The model's ``params_info`` is not modified.
        """
        # shallow copy of each entry so the model's params_info stays untouched
        params_info = {name: dict(info) for name, info in model.params_info.items()}
        jitter = config.jaxsolver.atol * 10

        if isinstance(model, (
            RED_IT, RED_IT_DA, RED_IT_IA,
            BufferGUTS_IT, BufferGUTS_IT_CA, BufferGUTS_IT_DA
        )):
            params_info["eps"] = {'name':'eps', 'initial':jitter, 'vary':False}

        for param_dict in params_info.values():
            for k, v in model._params_info_defaults.items():
                param_dict.setdefault(k, v)

        param_df = pd.DataFrame(params_info).T
        param_df["name"] = param_df["name"].apply(lambda s: re.sub(r"\d+", "", s).strip("_"))

        for (param_name, ), group in param_df.groupby(["name"]):

            dims = list(dict.fromkeys(group["dims"]))
            dims = tuple([]) if dims == [None] else tuple(dims)

            # missing priors show up as NaN in the DataFrame; treat them as None
            priors = list(dict.fromkeys(
                None if isinstance(p, float) and np.isnan(p) else p
                for p in group["prior"]
            ))
            if len(priors) > 1:
                # TODO: allow for parsing one N-D prior from multiple priors
                raise NotImplementedError(
                    f"Parameter '{param_name}' has multiple different priors "
                    f"{priors}. Specify a single prior for all its elements."
                )
            prior = priors[0]

            _min = np.min(np.ma.masked_invalid(group["min"].values.astype(float)))
            _max = np.max(np.ma.masked_invalid(group["max"].values.astype(float)))
            _init = group["initial"].values.astype(float)
            _free = group["vary"].values
            _fixed = np.logical_not(_free)

            # TODO: Another choice would be to parse vary=False priors as deterministic
            # and use a composite prior from a deterministic and a free prior as
            # the input into the model

            if prior is None:
                if default_prior == "uniform":
                    # fixed parameters get a very narrow interval around their initial value
                    _loc = _init * _fixed + _min * _free - jitter * _fixed
                    _scale = _init * _fixed + _max * _free + jitter * _fixed
                    _loc = _loc[0] if len(_loc) == 1 else _loc
                    _scale = _scale[0] if len(_scale) == 1 else _scale
                    prior = (
                        f"uniform(loc={cls._format_prior_value(_loc)},"
                        f"scale={cls._format_prior_value(_scale)})"
                    )
                elif default_prior == "lognorm":
                    # free parameters: wide lognormal; fixed ones: nearly a point mass
                    _s = 3 * _free + jitter * _fixed
                    _init = _init[0] if len(_init) == 1 else _init
                    _s = _s[0] if len(_s) == 1 else _s

                    prior = (
                        f"lognorm(scale={cls._format_prior_value(_init)},"
                        f"s={cls._format_prior_value(_s)})"
                    )
                else:
                    raise ValueError(
                        f"Default prior: '{default_prior}' is not implemented. "+
                        "Use one of 'uniform', 'lognorm' or specify priors for each "+
                        "parameter directly with: "+
                        f"`model.params_dict['prior'] = {default_prior}(...)`"
                    )

            prior = cls._normalize_prior_string(prior)

            param = Param(
                value=_init,
                free=np.max(_free),
                min=_min,
                max=_max,
                prior=prior,
                dims=dims
            )

            setattr(config.model_parameters, param_name, param)
|
|
252
|
+
|
guts_base/sim/report.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import itertools as it
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from pymob import SimulationBase
|
|
6
|
+
from pymob.sim.report import Report, reporting
|
|
7
|
+
|
|
8
|
+
from guts_base.plot import plot_survival_multipanel
|
|
9
|
+
from guts_base.sim.ecx import ECxEstimator
|
|
10
|
+
|
|
11
|
+
class GutsReport(Report):
    """Report with GUTS-specific sections: survival model fits and LC_x estimates."""

    def additional_reports(self, sim: "SimulationBase"):
        """Run the base report sections, then the GUTS-specific ones."""
        super().additional_reports(sim=sim)
        self.model_fits(sim)
        self.LCx_estimates(sim)

    @reporting
    def model_fits(self, sim: SimulationBase):
        """Plot posterior survival fits in a multipanel figure and reference it."""
        self._write("### Survival model fits")

        out_mp = plot_survival_multipanel(
            sim=sim,
            results=sim.inferer.idata.posterior_model_fits,
            ncols=6,
        )

        lab = self._label.format(placeholder='survival_fits')
        # NOTE(review): the published line read `f"})"`, which is a SyntaxError.
        # It is reconstructed as a markdown image reference following the pattern
        # used for the LC_x table. This assumes plot_survival_multipanel returns
        # the path of the saved figure -- confirm.
        self._write(rf"![Survival model fits \label{{{lab}}}]({os.path.basename(out_mp)})")

        return out_mp

    @reporting
    def LCx_estimates(self, sim):
        """Estimate LC_x for several effect levels, times and predefined scenarios.

        The table is written to 'lcx_estimates.csv' in ``sim.output_path`` and
        included in the report.
        """
        X = [0.1, 0.25, 0.5, 0.75, 0.9]
        T = [1, 2]
        P = sim.predefined_scenarios

        estimates = pd.DataFrame(
            list(it.product(X, T, P.keys())),
            columns=["x", "time", "scenario"]
        )

        ecx = []
        for _, row in estimates.iterrows():
            ecx_estimator = ECxEstimator(
                sim=sim,
                effect="survival",
                x=row.x,
                id=None,
                time=row.time,
                x_in=P[row.scenario],
            )

            ecx_estimator.estimate(
                mode=sim.ecx_mode,
                draws=250,
            )

            ecx.append(ecx_estimator.results)

        if ecx:
            results = pd.DataFrame(ecx)
            estimates[results.columns] = results

        # write the CSV that the report table links to
        file = os.path.join(sim.output_path, "lcx_estimates.csv")
        estimates.to_csv(file)
        lab = self._label.format(placeholder='$LC_x$ estimates')
        self._write_table(tab=estimates, label_insert=rf"$LC_x$ estimates \label{{{lab}}}]({os.path.basename(file)})")
|
|
72
|
+
|
guts_base/sim.py
ADDED
|
File without changes
|