guts-base 0.8.6__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +2 -1
- guts_base/data/__init__.py +1 -1
- guts_base/data/generator.py +2 -1
- guts_base/data/survival.py +6 -0
- guts_base/mod.py +24 -83
- guts_base/prob.py +23 -275
- guts_base/sim/__init__.py +10 -1
- guts_base/sim/base.py +285 -75
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +174 -59
- guts_base/sim/mempy.py +85 -70
- guts_base/sim/report.py +0 -1
- guts_base/sim/utils.py +10 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.1.dist-info}/METADATA +2 -3
- guts_base-1.0.1.dist-info/RECORD +25 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.6.dist-info/RECORD +0 -24
- {guts_base-0.8.6.dist-info → guts_base-1.0.1.dist-info}/WHEEL +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.1.dist-info}/entry_points.txt +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.1.dist-info}/top_level.txt +0 -0
guts_base/sim/ecx.py
CHANGED
|
@@ -9,39 +9,62 @@ from matplotlib import pyplot as plt
|
|
|
9
9
|
from tqdm import tqdm
|
|
10
10
|
|
|
11
11
|
from pymob import SimulationBase
|
|
12
|
+
from guts_base.sim.utils import GutsBaseError
|
|
12
13
|
|
|
13
14
|
class ECxEstimator:
|
|
14
15
|
"""Estimates the exposure level that corresponds to a given effect. The algorithm
|
|
15
|
-
operates by varying a given exposure profile (x_in)
|
|
16
|
+
operates by varying a given exposure profile (x_in). For each new estimation, a new
|
|
17
|
+
estimator is initialized.
|
|
18
|
+
|
|
19
|
+
Parameters
|
|
20
|
+
----------
|
|
21
|
+
|
|
22
|
+
sim : SimulationBase
|
|
23
|
+
This must be a pymob.SimulationBase object. If the ECxEstimator.estimate method
|
|
24
|
+
is used with the modes 'draws' or 'mean', sim must contain a posterior (sim.inferer.idata.posterior).
|
|
25
|
+
|
|
16
26
|
"""
|
|
17
27
|
_name = "EC"
|
|
28
|
+
_parameter_msg = (
|
|
29
|
+
"Manual estimation (mode='manual', without using posterior information) requires " +
|
|
30
|
+
"specification of parameters={...}. You can obtain and modify " +
|
|
31
|
+
"parameters using the pymob API: `sim.config.model_parameters.value_dict` " +
|
|
32
|
+
"returns a dictionary of DEFAULT PARAMETERS that you can customize to your liking " +
|
|
33
|
+
"(https://pymob.readthedocs.io/en/stable/api/pymob.sim.html#pymob.sim.config.Modelparameters.value_dict)."
|
|
34
|
+
)
|
|
18
35
|
|
|
19
36
|
def __init__(
|
|
20
37
|
self,
|
|
21
38
|
sim: SimulationBase,
|
|
22
39
|
effect: str,
|
|
23
|
-
x: float
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
x_in: Optional[xr.Dataset]=None,
|
|
40
|
+
x: float,
|
|
41
|
+
time: float,
|
|
42
|
+
x_in: xr.Dataset,
|
|
27
43
|
):
|
|
28
44
|
self.sim = sim.copy()
|
|
29
45
|
self.time = time
|
|
30
46
|
self.x = x
|
|
31
|
-
self.id = id
|
|
32
47
|
self.effect = effect
|
|
33
48
|
self._mode = None
|
|
34
49
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
50
|
+
# creates an empty observation dataset with the coordinates of the
|
|
51
|
+
# original observations (especially time), except the ID, which is overwritten
|
|
52
|
+
# and taken from the x_in dataset
|
|
53
|
+
pseudo_obs = self.sim.observations.isel(id=[0])
|
|
54
|
+
pseudo_obs = pseudo_obs.drop([v for v in pseudo_obs.data_vars.keys()])
|
|
55
|
+
pseudo_obs["id"] = x_in["id"]
|
|
56
|
+
|
|
57
|
+
self.sim.config.data_structure.survival.observed = False
|
|
58
|
+
self.sim.observations = pseudo_obs
|
|
59
|
+
|
|
60
|
+
# overwrite x_in to make sure that parse_input takes x_in from exposure and
|
|
61
|
+
# does not use the string that is tied to another data variable which was
|
|
62
|
+
# originally present
|
|
63
|
+
self.sim.config.simulation.x_in = ["exposure=exposure"]
|
|
39
64
|
|
|
40
65
|
# ensure correct coordinate order with x_in and raise errors early
|
|
41
66
|
self.sim.model_parameters["x_in"] = self.sim.parse_input("x_in", x_in)
|
|
42
67
|
|
|
43
|
-
self.sim.config.data_structure.survival.observed = False
|
|
44
|
-
self.sim.observations = self.sim.observations.sel(id=self.sim.coordinates["id"])
|
|
45
68
|
|
|
46
69
|
# fix time after observations have been set. The outcome of the simulation
|
|
47
70
|
# can depend on the time vector, because in e.g. IT models, the time resolution
|
|
@@ -51,7 +74,7 @@ class ECxEstimator:
|
|
|
51
74
|
self.sim.coordinates["time"], np.array(time, ndmin=1)
|
|
52
75
|
]))
|
|
53
76
|
|
|
54
|
-
self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims="time")
|
|
77
|
+
self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims=["time"])
|
|
55
78
|
self.sim.dispatch_constructor()
|
|
56
79
|
|
|
57
80
|
self.results = pd.Series({
|
|
@@ -63,6 +86,19 @@ class ECxEstimator:
|
|
|
63
86
|
"msg": np.nan
|
|
64
87
|
})
|
|
65
88
|
|
|
89
|
+
self.figure_profile_and_effect = None
|
|
90
|
+
self.figure_loss_curve = None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _assert_posterior(self):
|
|
94
|
+
try:
|
|
95
|
+
p = self.sim.inferer.idata.posterior
|
|
96
|
+
except AttributeError:
|
|
97
|
+
raise GutsBaseError(
|
|
98
|
+
"Using mode='mode' or mode='draws', but sim did not contain a posterior. " +
|
|
99
|
+
"('sim.inferer.idata.posterior'). " + self._parameter_msg
|
|
100
|
+
)
|
|
101
|
+
|
|
66
102
|
|
|
67
103
|
|
|
68
104
|
def _evaluate(self, factor, theta):
|
|
@@ -95,25 +131,128 @@ class ECxEstimator:
|
|
|
95
131
|
sample = {k: v["data"] for k, v in sample.to_dict()["data_vars"].items()}
|
|
96
132
|
return sample
|
|
97
133
|
|
|
98
|
-
def plot_loss_curve(self
|
|
99
|
-
|
|
134
|
+
def plot_loss_curve(self,
|
|
135
|
+
mode: Literal["draws", "mean", "manual"] = "draws",
|
|
136
|
+
draws: Optional[int] = None,
|
|
137
|
+
parameters: Optional[Dict[str,float|List[float]]] = None,
|
|
138
|
+
log_x0: float = 0.0,
|
|
139
|
+
force_draws: bool = False
|
|
140
|
+
):
|
|
141
|
+
"""
|
|
142
|
+
Parameters
|
|
143
|
+
----------
|
|
144
|
+
|
|
145
|
+
mode : Literal['draws', 'mean', 'manual']
|
|
146
|
+
mode of estimation. mode='mean' takes the mean of the posterior and estimates
|
|
147
|
+
the ECx for this singular value. mode='draws' takes samples from the posterior
|
|
148
|
+
and estimates the ECx for each of the parameter draws. mode='manual' takes
|
|
149
|
+
a parameter set (Dict) in the parameters argument and uses that for estimation.
|
|
150
|
+
Default: 'draws'
|
|
151
|
+
|
|
152
|
+
draws : int
|
|
153
|
+
Number of draws to take from the posterior. Only takes effect if mode='draws'.
|
|
154
|
+
Raises an exception if draws < 100, because this is insufficient for a
|
|
155
|
+
reasonable uncertainty estimate. Default: None (using all samples from the
|
|
156
|
+
posterior)
|
|
157
|
+
|
|
158
|
+
parameters : Dict[str,float|list[float]]
|
|
159
|
+
a parameter dictionary used as model parameters for finding the ECx
|
|
160
|
+
value. Default: None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
log_x0 : float
|
|
164
|
+
the starting value for the multiplication factor of the exposure profile for
|
|
165
|
+
the minimization algorithm. This value is on the log scale. This means,
|
|
166
|
+
exp(log_x0=0.0) = 1.0, which means that the log_x0=0.0 will start at an
|
|
167
|
+
unmodified exposure profile. Default: 0.0
|
|
168
|
+
|
|
169
|
+
force_draws : bool
|
|
170
|
+
Force the estimate method to accept a number of draws less than 100. Default: False
|
|
171
|
+
|
|
172
|
+
"""
|
|
173
|
+
draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)
|
|
100
174
|
|
|
101
|
-
factor = np.linspace(-2,2, 100)
|
|
102
|
-
y = list(map(partial(self._loss, theta=posterior_mean), factor))
|
|
103
175
|
|
|
176
|
+
factor = np.linspace(-2,2, 100) + log_x0
|
|
104
177
|
fig, ax = plt.subplots(1,1, sharey=True, figsize=(4, 3))
|
|
178
|
+
|
|
179
|
+
for i in tqdm(range(draws)):
|
|
180
|
+
if mode == "draws":
|
|
181
|
+
sample = self._posterior_sample(i)
|
|
182
|
+
elif mode == "mean":
|
|
183
|
+
sample = self._posterior_mean()
|
|
184
|
+
elif mode == "manual":
|
|
185
|
+
sample = parameters
|
|
186
|
+
else:
|
|
187
|
+
raise NotImplementedError(
|
|
188
|
+
f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
y = list(map(partial(self._loss, theta=sample), factor))
|
|
192
|
+
|
|
193
|
+
ax.plot(
|
|
194
|
+
np.exp(factor), y,
|
|
195
|
+
color="black",
|
|
196
|
+
)
|
|
197
|
+
|
|
105
198
|
ax.plot(
|
|
106
|
-
|
|
107
|
-
color="black",
|
|
199
|
+
[], [], color="black",
|
|
108
200
|
label=f"$\ell = S(t={self.time},x_{{in}}=C_{{ext}} \cdot \phi) - {self.x}$"
|
|
109
201
|
)
|
|
110
202
|
ax.set_ylabel("Loss ($\ell$)")
|
|
111
203
|
ax.set_xlabel("Multiplication factor ($\phi$)")
|
|
112
204
|
ax.set_title(f"ID: {self.sim.coordinates['id'][0]}")
|
|
113
|
-
ax.set_ylim(0,
|
|
205
|
+
ax.set_ylim(0, ax.get_ylim()[1] * 1.25)
|
|
114
206
|
ax.legend(frameon=False)
|
|
115
207
|
fig.tight_layout()
|
|
116
208
|
|
|
209
|
+
self.figure_loss_curve = fig
|
|
210
|
+
|
|
211
|
+
def _check_mode_and_draws_and_parameters(self, mode, draws, parameters, force_draws):
|
|
212
|
+
|
|
213
|
+
if mode == "draws":
|
|
214
|
+
self._assert_posterior()
|
|
215
|
+
|
|
216
|
+
if draws is None:
|
|
217
|
+
draws = (
|
|
218
|
+
self.sim.inferer.idata.posterior.sizes["chain"] *
|
|
219
|
+
self.sim.inferer.idata.posterior.sizes["draw"]
|
|
220
|
+
)
|
|
221
|
+
elif draws < 100 and not force_draws:
|
|
222
|
+
raise GutsBaseError(
|
|
223
|
+
"draws must be larger than 100. Preferably > 1000. " +
|
|
224
|
+
f"If you don't want uncertainty assessment of the {self._name} " +
|
|
225
|
+
"estimates, use mode='mean'. If you really want to use less than " +
|
|
226
|
+
"100 draws, use force_draws=True at your own risk."
|
|
227
|
+
)
|
|
228
|
+
else:
|
|
229
|
+
pass
|
|
230
|
+
|
|
231
|
+
warnings.warn(
|
|
232
|
+
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
elif mode == "mean":
|
|
236
|
+
self._assert_posterior()
|
|
237
|
+
|
|
238
|
+
draws = 1
|
|
239
|
+
|
|
240
|
+
warnings.warn(
|
|
241
|
+
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
elif mode == "manual":
|
|
245
|
+
draws = 1
|
|
246
|
+
if parameters is None:
|
|
247
|
+
raise GutsBaseError(self._parameter_msg)
|
|
248
|
+
else:
|
|
249
|
+
raise GutsBaseError(
|
|
250
|
+
f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
return draws
|
|
254
|
+
|
|
255
|
+
|
|
117
256
|
def estimate(
|
|
118
257
|
self,
|
|
119
258
|
mode: Literal["draws", "mean", "manual"] = "draws",
|
|
@@ -124,6 +263,7 @@ class ECxEstimator:
|
|
|
124
263
|
optimizer_tol: float = 1e-5,
|
|
125
264
|
method: str = "cobyla",
|
|
126
265
|
show_plot: bool = True,
|
|
266
|
+
force_draws: bool = False,
|
|
127
267
|
**optimizer_kwargs
|
|
128
268
|
):
|
|
129
269
|
"""The minimizer for the EC_x operates on the unbounded linear scale, estimating
|
|
@@ -177,48 +317,15 @@ class ECxEstimator:
|
|
|
177
317
|
show_plot : bool
|
|
178
318
|
Show the results plot of the lpx. Default: True
|
|
179
319
|
|
|
320
|
+
force_draws : bool
|
|
321
|
+
Force the estimate method to accept a number of draws less than 100. Default: False
|
|
322
|
+
|
|
180
323
|
optimizer_kwargs :
|
|
181
324
|
Additional arguments to pass to the optimizer
|
|
182
325
|
|
|
183
326
|
"""
|
|
184
327
|
x0_tries = np.array([0.0, -1.0, 1.0, -2.0, 2.0]) + log_x0
|
|
185
|
-
|
|
186
|
-
if mode == "draws":
|
|
187
|
-
if draws is None:
|
|
188
|
-
draws = (
|
|
189
|
-
self.sim.inferer.idata.posterior.sizes["chain"] *
|
|
190
|
-
self.sim.inferer.idata.posterior.sizes["draw"]
|
|
191
|
-
)
|
|
192
|
-
elif draws < 100:
|
|
193
|
-
raise ValueError(
|
|
194
|
-
"draws must be larger than 100. Preferably > 1000. "
|
|
195
|
-
f"If you don't want uncertainty assessment of the {self._name} "
|
|
196
|
-
"estimates, use mode='mean'"
|
|
197
|
-
)
|
|
198
|
-
else:
|
|
199
|
-
pass
|
|
200
|
-
|
|
201
|
-
warnings.warn(
|
|
202
|
-
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
203
|
-
)
|
|
204
|
-
|
|
205
|
-
elif mode == "mean":
|
|
206
|
-
draws = 1
|
|
207
|
-
|
|
208
|
-
warnings.warn(
|
|
209
|
-
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
210
|
-
)
|
|
211
|
-
|
|
212
|
-
elif mode == "manual":
|
|
213
|
-
draws = 1
|
|
214
|
-
if parameters is None:
|
|
215
|
-
raise ValueError(
|
|
216
|
-
"parameters need to be provided if mode='manual'"
|
|
217
|
-
)
|
|
218
|
-
else:
|
|
219
|
-
raise NotImplementedError(
|
|
220
|
-
f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
|
|
221
|
-
)
|
|
328
|
+
draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)
|
|
222
329
|
|
|
223
330
|
self._mode = mode
|
|
224
331
|
mult_factor = []
|
|
@@ -310,7 +417,7 @@ class ECxEstimator:
|
|
|
310
417
|
self.sim.dispatch_constructor()
|
|
311
418
|
|
|
312
419
|
if self._mode is None:
|
|
313
|
-
raise
|
|
420
|
+
raise GutsBaseError(
|
|
314
421
|
"Run .estimate() before plot_profile_and_effect()"
|
|
315
422
|
)
|
|
316
423
|
elif self._mode == "mean" or self._mode == "draws":
|
|
@@ -375,6 +482,8 @@ class ECxEstimator:
|
|
|
375
482
|
ax2.set_ylim(0, None)
|
|
376
483
|
fig.tight_layout()
|
|
377
484
|
|
|
485
|
+
self.figure_profile_and_effect = fig
|
|
486
|
+
|
|
378
487
|
self.sim.coordinates["time"] = coordinates_backup
|
|
379
488
|
self.sim.dispatch_constructor()
|
|
380
489
|
|
|
@@ -396,4 +505,10 @@ class LPxEstimator(ECxEstimator):
|
|
|
396
505
|
):
|
|
397
506
|
x_in = sim.model_parameters["x_in"].sel(id=[id])
|
|
398
507
|
time = sim.coordinates["time"][-1]
|
|
399
|
-
super().__init__(
|
|
508
|
+
super().__init__(
|
|
509
|
+
sim=sim,
|
|
510
|
+
effect="survival",
|
|
511
|
+
x=x,
|
|
512
|
+
time=time,
|
|
513
|
+
x_in=x_in
|
|
514
|
+
)
|
guts_base/sim/mempy.py
CHANGED
|
@@ -1,44 +1,64 @@
|
|
|
1
1
|
import pathlib
|
|
2
|
-
from typing import Dict, Optional, Literal
|
|
2
|
+
from typing import Dict, Optional, Literal, Protocol, TypedDict, List
|
|
3
3
|
import re
|
|
4
|
-
|
|
4
|
+
import os
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import xarray as xr
|
|
8
|
-
from pymob import SimulationBase
|
|
9
8
|
from pymob.sim.config import Config, DataVariable, Datastructure
|
|
10
9
|
from pymob.sim.parameters import Param
|
|
11
10
|
from guts_base.sim import GutsBase
|
|
12
|
-
from mempy.model import (
|
|
13
|
-
Model,
|
|
14
|
-
RED_IT,
|
|
15
|
-
RED_SD,
|
|
16
|
-
RED_IT_DA,
|
|
17
|
-
RED_SD_DA,
|
|
18
|
-
RED_IT_IA,
|
|
19
|
-
RED_SD_IA,
|
|
20
|
-
BufferGUTS_IT,
|
|
21
|
-
BufferGUTS_IT_CA,
|
|
22
|
-
BufferGUTS_IT_DA
|
|
23
|
-
)
|
|
24
11
|
|
|
25
12
|
__all__ = [
|
|
26
13
|
"PymobSimulator",
|
|
27
14
|
]
|
|
28
15
|
|
|
16
|
+
class ParamsInfoDict(TypedDict):
|
|
17
|
+
name: str
|
|
18
|
+
min: float
|
|
19
|
+
max: float
|
|
20
|
+
initial: float
|
|
21
|
+
vary: bool
|
|
22
|
+
prior: str
|
|
23
|
+
|
|
24
|
+
class StateVariablesDict(TypedDict):
|
|
25
|
+
dimensions: List[str]
|
|
26
|
+
observed: bool
|
|
27
|
+
y0: List[float]
|
|
28
|
+
|
|
29
|
+
class Model(Protocol):
|
|
30
|
+
extra_dim: Optional[str]
|
|
31
|
+
params_info: Dict[str, ParamsInfoDict]
|
|
32
|
+
state_variables: Dict[str, StateVariablesDict]
|
|
33
|
+
_params_info_defaults: Dict[str, ParamsInfoDict]
|
|
34
|
+
_it_model: bool
|
|
35
|
+
|
|
36
|
+
@staticmethod
|
|
37
|
+
def _rhs_jax():
|
|
38
|
+
raise NotImplementedError
|
|
39
|
+
|
|
40
|
+
@staticmethod
|
|
41
|
+
def _solver_post_processing():
|
|
42
|
+
raise NotImplementedError
|
|
43
|
+
|
|
44
|
+
@staticmethod
|
|
45
|
+
def _likelihood_func_jax():
|
|
46
|
+
raise NotImplementedError
|
|
47
|
+
|
|
48
|
+
|
|
29
49
|
class PymobSimulator(GutsBase):
|
|
30
50
|
|
|
31
51
|
@classmethod
|
|
32
|
-
def
|
|
52
|
+
def from_model_and_dataset(
|
|
33
53
|
cls,
|
|
34
|
-
exposure_data: Dict,
|
|
35
|
-
survival_data: Dict,
|
|
36
54
|
model: Model,
|
|
55
|
+
exposure_data: Dict[str, pd.DataFrame],
|
|
56
|
+
survival_data: pd.DataFrame,
|
|
37
57
|
info_dict: Dict = {},
|
|
38
58
|
pymob_config: Optional[Config] = None,
|
|
39
59
|
output_directory: str|pathlib.Path = pathlib.Path("output/pymob"),
|
|
40
60
|
default_prior: Literal["uniform", "lognorm"] = "lognorm",
|
|
41
|
-
) ->
|
|
61
|
+
) -> "PymobSimulator":
|
|
42
62
|
"""Construct a PymobSimulator from the given model, exposure data and survival data.
|
|
43
63
|
"""
|
|
44
64
|
|
|
@@ -46,7 +66,7 @@ class PymobSimulator(GutsBase):
|
|
|
46
66
|
cfg = Config()
|
|
47
67
|
# Configure: The configuration can be overridden in a subclass to override the
|
|
48
68
|
# configuration
|
|
49
|
-
cls.
|
|
69
|
+
cls._configure(config=cfg)
|
|
50
70
|
else:
|
|
51
71
|
cfg = pymob_config
|
|
52
72
|
|
|
@@ -55,6 +75,14 @@ class PymobSimulator(GutsBase):
|
|
|
55
75
|
|
|
56
76
|
cfg.case_study.output = str(output_directory)
|
|
57
77
|
|
|
78
|
+
# overrides scenario path. This means the scenario is also expected in the
|
|
79
|
+
# same folder
|
|
80
|
+
cfg.case_study.scenario_path_override = str(output_directory)
|
|
81
|
+
cfg.case_study.scenario = output_directory.stem
|
|
82
|
+
cfg.case_study.data = cfg.case_study.output_path
|
|
83
|
+
cfg.case_study.observations = "observations.nc"
|
|
84
|
+
cfg.create_directory(directory="results", force=True)
|
|
85
|
+
|
|
58
86
|
# parse observations
|
|
59
87
|
# obs can be simply subset by selection obs.sel(substance="Exposure-Dime")
|
|
60
88
|
observations = xr.combine_by_coords([
|
|
@@ -62,9 +90,19 @@ class PymobSimulator(GutsBase):
|
|
|
62
90
|
cls._survival_data_to_xarray(survival_data)
|
|
63
91
|
])
|
|
64
92
|
|
|
93
|
+
observations.to_netcdf(
|
|
94
|
+
os.path.join(cfg.case_study.output_path, cfg.case_study.observations)
|
|
95
|
+
)
|
|
96
|
+
|
|
65
97
|
# configure model and likelihood function
|
|
66
|
-
|
|
67
|
-
cfg.
|
|
98
|
+
# extract the fully qualified name of the model module.name
|
|
99
|
+
cfg.simulation.model_class = "{module}.{name}".format(
|
|
100
|
+
module=model.__module__, name=type(model).__name__
|
|
101
|
+
)
|
|
102
|
+
cfg.inference_numpyro.user_defined_error_model = "{module}.{name}".format(
|
|
103
|
+
module=model._likelihood_func_jax.__module__,
|
|
104
|
+
name=model._likelihood_func_jax.__name__
|
|
105
|
+
)
|
|
68
106
|
|
|
69
107
|
# derive data structure and params from the model instance
|
|
70
108
|
cls._set_data_structure(config=cfg, model=model)
|
|
@@ -75,35 +113,37 @@ class PymobSimulator(GutsBase):
|
|
|
75
113
|
cfg.simulation.y0 = [f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v]
|
|
76
114
|
|
|
77
115
|
# create a simulation object
|
|
116
|
+
# It is essential that all post processing tasks are done in self.setup()
|
|
117
|
+
# which is extended below. This ensures that the simulation can also be run
|
|
118
|
+
# from automated tools like pymob-infer
|
|
78
119
|
sim = cls(config=cfg)
|
|
79
|
-
sim.
|
|
80
|
-
|
|
81
|
-
# initialize
|
|
82
|
-
sim.load_modules()
|
|
83
|
-
sim.set_logger()
|
|
84
|
-
|
|
85
|
-
sim.initialize(input={"observations": observations, "model": model})
|
|
86
|
-
|
|
87
|
-
sim.validate()
|
|
88
|
-
sim.dispatch_constructor()
|
|
89
|
-
|
|
90
|
-
|
|
120
|
+
sim.setup()
|
|
91
121
|
return sim
|
|
92
122
|
|
|
93
|
-
def
|
|
94
|
-
|
|
95
|
-
|
|
123
|
+
def reset_observations(self):
|
|
124
|
+
"""Resets the observations to the original observations after using .from_mempy(...)
|
|
125
|
+
This also resets the sim.coordinates dictionary.
|
|
126
|
+
"""
|
|
127
|
+
|
|
128
|
+
self.observations = self._obs_backup
|
|
96
129
|
|
|
97
|
-
|
|
130
|
+
def setup(self, **evaluator_kwargs):
|
|
131
|
+
super().setup(**evaluator_kwargs)
|
|
132
|
+
self._obs_backup = self.observations.copy(deep=True)
|
|
98
133
|
|
|
99
134
|
|
|
100
135
|
@classmethod
|
|
101
|
-
def
|
|
136
|
+
def _configure(cls, config: Config):
|
|
102
137
|
"""This is normally set in the configuration file passed to a SimulationBase class.
|
|
103
138
|
Since the mempy to pymob converter initializes pymob.SimulationBase from scratch
|
|
104
139
|
(without using a config file), the necessary settings have to be specified here.
|
|
105
140
|
"""
|
|
106
141
|
config.case_study.output = "results"
|
|
142
|
+
config.case_study.simulation = "PymobSimulator"
|
|
143
|
+
|
|
144
|
+
# this must be named guts_base, which is the name of the pip package and
|
|
145
|
+
# this regulates which packages are loaded.
|
|
146
|
+
config.case_study.name = "guts_base"
|
|
107
147
|
|
|
108
148
|
config.simulation.x_dimension = "time"
|
|
109
149
|
config.simulation.batch_dimension = "id"
|
|
@@ -132,33 +172,6 @@ class PymobSimulator(GutsBase):
|
|
|
132
172
|
config.inference_numpyro.svi_iterations = 10_000
|
|
133
173
|
config.inference_numpyro.svi_learning_rate = 0.001
|
|
134
174
|
|
|
135
|
-
@staticmethod
|
|
136
|
-
def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
|
|
137
|
-
"""
|
|
138
|
-
TODO: Currently no rect interpolation
|
|
139
|
-
"""
|
|
140
|
-
arrays = {}
|
|
141
|
-
for key, df in exposure_data.items():
|
|
142
|
-
# this override is necessary to make all dimensions work out
|
|
143
|
-
df.index.name = "time"
|
|
144
|
-
arrays.update({
|
|
145
|
-
key: df.to_xarray().to_dataarray(dim="id", name=key)
|
|
146
|
-
})
|
|
147
|
-
|
|
148
|
-
exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
|
|
149
|
-
exposure_array = exposure_array.transpose("id", "time", ...)
|
|
150
|
-
return xr.Dataset({"exposure": exposure_array})
|
|
151
|
-
|
|
152
|
-
@staticmethod
|
|
153
|
-
def _survival_data_to_xarray(survival_data: pd.DataFrame):
|
|
154
|
-
# TODO: survival name is currently not kept because the raw data is not transferred from the survival
|
|
155
|
-
survival_data.index.name = "time"
|
|
156
|
-
|
|
157
|
-
survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
|
|
158
|
-
survival_array = survival_array.transpose("id", "time", ...)
|
|
159
|
-
arrays = {"survival": survival_array}
|
|
160
|
-
return xr.Dataset(arrays)
|
|
161
|
-
|
|
162
175
|
@classmethod
|
|
163
176
|
def _set_data_structure(cls, config: Config, model: Model):
|
|
164
177
|
"""Takes a dictionary that is specified in the model and uses only keys that
|
|
@@ -179,10 +192,7 @@ class PymobSimulator(GutsBase):
|
|
|
179
192
|
def _set_params(cls, config: Config, model: Model, default_prior: str):
|
|
180
193
|
params_info = model.params_info
|
|
181
194
|
|
|
182
|
-
if
|
|
183
|
-
RED_IT, RED_IT_DA, RED_IT_IA,
|
|
184
|
-
BufferGUTS_IT, BufferGUTS_IT_CA, BufferGUTS_IT_DA
|
|
185
|
-
)):
|
|
195
|
+
if model._it_model:
|
|
186
196
|
eps = config.jaxsolver.atol * 10
|
|
187
197
|
params_info["eps"] = {'name':'eps', 'initial':eps, 'vary':False}
|
|
188
198
|
|
|
@@ -210,6 +220,11 @@ class PymobSimulator(GutsBase):
|
|
|
210
220
|
_init = group["initial"].values.astype(float)
|
|
211
221
|
_free = group["vary"].values
|
|
212
222
|
|
|
223
|
+
if isinstance(_min, np.ma.core.MaskedConstant):
|
|
224
|
+
_min = None
|
|
225
|
+
if isinstance(_max, np.ma.core.MaskedConstant):
|
|
226
|
+
_max = None
|
|
227
|
+
|
|
213
228
|
# TODO: allow for parsing one N-D prior from multiple priors
|
|
214
229
|
# TODO: Another choice would be to parse vary=False priors as deterministic
|
|
215
230
|
# and use a composite prior from a deterministic and a free prior as
|
guts_base/sim/report.py
CHANGED
guts_base/sim/utils.py
ADDED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: guts_base
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 1.0.1
|
|
4
4
|
Summary: Basic GUTS model implementation in pymob
|
|
5
5
|
Author-email: Florian Schunck <fluncki@protonmail.com>
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -692,8 +692,7 @@ License-File: LICENSE
|
|
|
692
692
|
Requires-Dist: openpyxl>=3.1.3
|
|
693
693
|
Requires-Dist: Bottleneck>=1.5.0
|
|
694
694
|
Requires-Dist: expydb>=0.6.0
|
|
695
|
-
Requires-Dist:
|
|
696
|
-
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.4.1
|
|
695
|
+
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
|
|
697
696
|
Provides-Extra: dev
|
|
698
697
|
Requires-Dist: pytest>=7.3; extra == "dev"
|
|
699
698
|
Requires-Dist: bumpver; extra == "dev"
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
guts_base/__init__.py,sha256=AqEsieZlFfc_SqQ4wN-pCXvqRCkjfF35TSOWiO_9hco,227
|
|
2
|
+
guts_base/mod.py,sha256=AzOCg1A8FP5EtVfp-66HT7G7h_wnHkenieaxTc9qCyk,5796
|
|
3
|
+
guts_base/plot.py,sha256=rjIzyPUikOtubjWI5mX_ZY8ARWl2NiHDmwAZAdoaPS8,5693
|
|
4
|
+
guts_base/prob.py,sha256=ITwo5dAGMHr5xTldilHMbKU6AFsWo4_ZwbfaXh97Gew,5443
|
|
5
|
+
guts_base/data/__init__.py,sha256=JBgft1DTledwvB5hRZnyGiKWv-RXo1OIpb5kJXloOmo,826
|
|
6
|
+
guts_base/data/expydb.py,sha256=Kcc6CeZMl3oEelk5UBN9VEfwgNF3CzTh13ooVkufAjE,8218
|
|
7
|
+
guts_base/data/generator.py,sha256=rGOZU3B0Ho8V6KtfjcAmely8lnlqNFV8cRyGboayTRc,2910
|
|
8
|
+
guts_base/data/openguts.py,sha256=WvhYl_AOdvNgzrcVS2f_PYbXNH_wSAz2uIBSR6BMSh0,11078
|
|
9
|
+
guts_base/data/preprocessing.py,sha256=qggYkx2x62ingU1BNhJFyL1eQdFQsDJR2lefVfVWW2U,1732
|
|
10
|
+
guts_base/data/survival.py,sha256=U-Ehloo8vnD81VeIglXLEUHX9lt7SjtEs2YEB0D9FHE,5096
|
|
11
|
+
guts_base/data/time_of_death.py,sha256=hwngUwfRP3u8WmD3dHyXrphuu5d8ZJTKyBovGRwAHNQ,21014
|
|
12
|
+
guts_base/data/utils.py,sha256=u3gGDJK15MfRUP4iIxsS-I1oqxD2qH_ugsT7o_Eac18,236
|
|
13
|
+
guts_base/sim/__init__.py,sha256=5VgzsOXjMsylfWy4neeovZBG1w6G6p_tFaIfhnIeLPM,415
|
|
14
|
+
guts_base/sim/base.py,sha256=7u7gj6HQSCtFUXrrX-ngfWsWnFEMk5NrDVJH0CIolAs,28960
|
|
15
|
+
guts_base/sim/constructors.py,sha256=Kz9FHIH3EHsSIKd9sQgHa3eveniFifFlk1Hf-QR69Pg,875
|
|
16
|
+
guts_base/sim/ecx.py,sha256=peO93pKcQwF5eULo5BdUl4KLnXzrTfvRGI8qsbAYQKU,19137
|
|
17
|
+
guts_base/sim/mempy.py,sha256=IHd87UrmdXpC7y7q0IjYQJH075frjbp2a-dMVBeqZ0U,10164
|
|
18
|
+
guts_base/sim/report.py,sha256=swTAp1qNTjDwWh-YOq_TLqtSxjQvjdwaNVHpEW8kXK4,1959
|
|
19
|
+
guts_base/sim/utils.py,sha256=Qj_FPH6kywVxOwgCerS7w5CyuYR9HKmvBWFpmxwDFgk,256
|
|
20
|
+
guts_base-1.0.1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
21
|
+
guts_base-1.0.1.dist-info/METADATA,sha256=vYcnGPbbGHEu-1yfF_OWliNRkQ4Ly3IDDK9Khs1DlCc,45406
|
|
22
|
+
guts_base-1.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
guts_base-1.0.1.dist-info/entry_points.txt,sha256=icsHzG2jQ90ZS7XvLsI5Qj0-qGuPv2T0RBVN5daGCPU,183
|
|
24
|
+
guts_base-1.0.1.dist-info/top_level.txt,sha256=PxhBgUd4r39W_VI4FyJjARwKbV5_glgCVnd6v_zAGdE,10
|
|
25
|
+
guts_base-1.0.1.dist-info/RECORD,,
|
guts_base/sim.py
DELETED
|
File without changes
|
guts_base-0.8.6.dist-info/RECORD
DELETED
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
guts_base/__init__.py,sha256=dYv5zH5jYPdPZXIzku6YPkvB47XD8qaTlY-0olh-4cA,208
|
|
2
|
-
guts_base/mod.py,sha256=aLb09E-6CwgKgq0GbwIT_Alv167cVFA5zAxwx80F1aE,7454
|
|
3
|
-
guts_base/plot.py,sha256=rjIzyPUikOtubjWI5mX_ZY8ARWl2NiHDmwAZAdoaPS8,5693
|
|
4
|
-
guts_base/prob.py,sha256=k1b-gc8yZNg8DMnxJHJzCy8Pud8_qEygGkrnmS2AcMY,14004
|
|
5
|
-
guts_base/sim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
-
guts_base/data/__init__.py,sha256=U5sJJgPwEVHVGTLeZNwl8VKLIgn1twpDNkOCMRSV1Nk,808
|
|
7
|
-
guts_base/data/expydb.py,sha256=Kcc6CeZMl3oEelk5UBN9VEfwgNF3CzTh13ooVkufAjE,8218
|
|
8
|
-
guts_base/data/generator.py,sha256=fos7y-xxTlBVpYsHD_2kdYSH40KILPUQ_Jy0mwjO5jc,2897
|
|
9
|
-
guts_base/data/openguts.py,sha256=WvhYl_AOdvNgzrcVS2f_PYbXNH_wSAz2uIBSR6BMSh0,11078
|
|
10
|
-
guts_base/data/preprocessing.py,sha256=qggYkx2x62ingU1BNhJFyL1eQdFQsDJR2lefVfVWW2U,1732
|
|
11
|
-
guts_base/data/survival.py,sha256=yNNZ7MabzrMeqPHOa8-LPxCgsIN4_X0aj_ANmeC9Wd8,4878
|
|
12
|
-
guts_base/data/time_of_death.py,sha256=hwngUwfRP3u8WmD3dHyXrphuu5d8ZJTKyBovGRwAHNQ,21014
|
|
13
|
-
guts_base/data/utils.py,sha256=u3gGDJK15MfRUP4iIxsS-I1oqxD2qH_ugsT7o_Eac18,236
|
|
14
|
-
guts_base/sim/__init__.py,sha256=vgZu2oi4RHIQuLS19570xBdLQasL75puxabGrbLuIGA,276
|
|
15
|
-
guts_base/sim/base.py,sha256=lxfHQ57pXKz1ZGnFqOIoeUplzvFWM1OjDX0xYQpSgaw,20921
|
|
16
|
-
guts_base/sim/ecx.py,sha256=IvZQnmMs3FLX6unuj7mPv2AaZ7l-S4TJmYmPI4fXE1k,14573
|
|
17
|
-
guts_base/sim/mempy.py,sha256=xsA3Q7jQrdcmQZJzKBdQnr8yZk-f36YgbLlV-VvJpxg,9514
|
|
18
|
-
guts_base/sim/report.py,sha256=D_6eoSvjdcsDtB0PlmNkzC1LRk_7WXzb_T6D6AZjQNE,1985
|
|
19
|
-
guts_base-0.8.6.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
20
|
-
guts_base-0.8.6.dist-info/METADATA,sha256=n0OWLZd2CZWg7j7M2sqcXfFuiOSPqSG60KvQ9IuoqHM,45437
|
|
21
|
-
guts_base-0.8.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
22
|
-
guts_base-0.8.6.dist-info/entry_points.txt,sha256=icsHzG2jQ90ZS7XvLsI5Qj0-qGuPv2T0RBVN5daGCPU,183
|
|
23
|
-
guts_base-0.8.6.dist-info/top_level.txt,sha256=PxhBgUd4r39W_VI4FyJjARwKbV5_glgCVnd6v_zAGdE,10
|
|
24
|
-
guts_base-0.8.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|