guts-base 0.8.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +2 -1
- guts_base/data/__init__.py +1 -1
- guts_base/data/generator.py +2 -1
- guts_base/data/survival.py +6 -0
- guts_base/mod.py +24 -83
- guts_base/prob.py +23 -275
- guts_base/sim/__init__.py +10 -1
- guts_base/sim/base.py +285 -75
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +168 -58
- guts_base/sim/mempy.py +85 -70
- guts_base/sim/report.py +0 -1
- guts_base/sim/utils.py +10 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/METADATA +2 -3
- guts_base-1.0.0.dist-info/RECORD +25 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.6.dist-info/RECORD +0 -24
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/WHEEL +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/entry_points.txt +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/top_level.txt +0 -0
guts_base/sim/ecx.py
CHANGED
|
@@ -9,39 +9,57 @@ from matplotlib import pyplot as plt
|
|
|
9
9
|
from tqdm import tqdm
|
|
10
10
|
|
|
11
11
|
from pymob import SimulationBase
|
|
12
|
+
from guts_base.sim.utils import GutsBaseError
|
|
12
13
|
|
|
13
14
|
class ECxEstimator:
|
|
14
15
|
"""Estimates the exposure level that corresponds to a given effect. The algorithm
|
|
15
|
-
operates by varying a given exposure profile (x_in)
|
|
16
|
+
operates by varying a given exposure profile (x_in). For each new estimation, a new
|
|
17
|
+
estimator is initialized.
|
|
18
|
+
|
|
19
|
+
Parameters
|
|
20
|
+
----------
|
|
21
|
+
|
|
22
|
+
sim : SimulationBase
|
|
23
|
+
This must be a pymob.SimulationBase object. If the ECxEstimator.estimate method
|
|
24
|
+
is used with the modes 'draw' or 'mean'
|
|
25
|
+
|
|
16
26
|
"""
|
|
17
27
|
_name = "EC"
|
|
28
|
+
_parameter_msg = (
|
|
29
|
+
"Manual estimation (mode='manual', without using posterior information) requires " +
|
|
30
|
+
"specification of parameters={...}. You can obtain and modify " +
|
|
31
|
+
"parameters using the pymob API: `sim.config.model_parameters.value_dict` " +
|
|
32
|
+
"returns a dictionary of DEFAULT PARAMETERS that you can customize to your liking " +
|
|
33
|
+
"(https://pymob.readthedocs.io/en/stable/api/pymob.sim.html#pymob.sim.config.Modelparameters.value_dict)."
|
|
34
|
+
)
|
|
18
35
|
|
|
19
36
|
def __init__(
|
|
20
37
|
self,
|
|
21
38
|
sim: SimulationBase,
|
|
22
39
|
effect: str,
|
|
23
|
-
x: float
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
x_in: Optional[xr.Dataset]=None,
|
|
40
|
+
x: float,
|
|
41
|
+
time: float,
|
|
42
|
+
x_in: xr.Dataset,
|
|
27
43
|
):
|
|
28
44
|
self.sim = sim.copy()
|
|
29
45
|
self.time = time
|
|
30
46
|
self.x = x
|
|
31
|
-
self.id = id
|
|
32
47
|
self.effect = effect
|
|
33
48
|
self._mode = None
|
|
34
49
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
50
|
+
# creates an empty observation dataset with the coordinates of the
|
|
51
|
+
# original observations (especially time), except the ID, which is overwritten
|
|
52
|
+
# and taken from the x_in dataset
|
|
53
|
+
pseudo_obs = self.sim.observations.isel(id=[0])
|
|
54
|
+
pseudo_obs = pseudo_obs.drop(["exposure","survival"])
|
|
55
|
+
pseudo_obs["id"] = x_in["id"]
|
|
56
|
+
|
|
57
|
+
self.sim.config.data_structure.survival.observed = False
|
|
58
|
+
self.sim.observations = pseudo_obs
|
|
39
59
|
|
|
40
60
|
# ensure correct coordinate order with x_in and raise errors early
|
|
41
61
|
self.sim.model_parameters["x_in"] = self.sim.parse_input("x_in", x_in)
|
|
42
62
|
|
|
43
|
-
self.sim.config.data_structure.survival.observed = False
|
|
44
|
-
self.sim.observations = self.sim.observations.sel(id=self.sim.coordinates["id"])
|
|
45
63
|
|
|
46
64
|
# fix time after observations have been set. The outcome of the simulation
|
|
47
65
|
# can dependend on the time vector, because in e.g. IT models, the time resolution
|
|
@@ -63,6 +81,19 @@ class ECxEstimator:
|
|
|
63
81
|
"msg": np.nan
|
|
64
82
|
})
|
|
65
83
|
|
|
84
|
+
self.figure_profile_and_effect = None
|
|
85
|
+
self.figure_loss_curve = None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _assert_posterior(self):
    """Ensure that the wrapped simulation carries posterior samples.

    The 'mean' and 'draws' modes read parameter values from
    ``self.sim.inferer.idata.posterior``. This check fails early with an
    informative message instead of an opaque ``AttributeError`` later on.

    Raises
    ------
    GutsBaseError
        If ``self.sim.inferer.idata.posterior`` cannot be accessed.
    """
    try:
        # attribute access only; the posterior itself is not needed here
        self.sim.inferer.idata.posterior
    except AttributeError as err:
        raise GutsBaseError(
            "Using mode='mean' or mode='draws', but sim did not contain a posterior. " +
            "('sim.inferer.idata.posterior'). " + self._parameter_msg
        ) from err
|
|
96
|
+
|
|
66
97
|
|
|
67
98
|
|
|
68
99
|
def _evaluate(self, factor, theta):
|
|
@@ -95,25 +126,128 @@ class ECxEstimator:
|
|
|
95
126
|
sample = {k: v["data"] for k, v in sample.to_dict()["data_vars"].items()}
|
|
96
127
|
return sample
|
|
97
128
|
|
|
98
|
-
def plot_loss_curve(self,
    mode: Literal["draws", "mean", "manual"] = "draws",
    draws: Optional[int] = None,
    parameters: Optional[Dict[str,float|List[float]]] = None,
    log_x0: float = 0.0,
    force_draws: bool = False
):
    """Plot the loss as a function of the exposure multiplication factor.

    The loss is evaluated on a grid of 100 log-scale multiplication factors
    spanning ``log_x0 - 2`` to ``log_x0 + 2``. It is plotted against the factor
    on the linear scale (``exp(factor)``), with one black curve per parameter
    set. The resulting figure is stored in ``self.figure_loss_curve``.

    Parameters
    ----------

    mode : Literal['draws', 'mean', 'manual']
        Source of the model parameters.
        - mode='mean' uses the posterior mean and draws a single curve.
        - mode='draws' uses individual posterior draws and draws one curve
          per draw.
        - mode='manual' uses the parameter set passed in ``parameters`` and
          draws a single curve.
        Default: 'draws'

    draws : int
        Number of posterior draws to plot. Only takes effect if mode='draws'.
        Raises an exception if draws < 100 unless ``force_draws=True``.
        Default: None (using all samples from the posterior)

    parameters : Dict[str,float|list[float]]
        Parameter dictionary used as model parameters if mode='manual'.
        Default: None

    log_x0 : float
        Center of the grid of multiplication factors on the log scale.
        exp(log_x0=0.0) = 1.0, i.e. the default grid is centered on the
        unmodified exposure profile. Default: 0.0

    force_draws : bool
        Accept a number of draws less than 100. Default: False

    Raises
    ------
    GutsBaseError
        Raised by ``_check_mode_and_draws_and_parameters`` in these cases:
        - the mode is unknown;
        - a posterior is required but missing;
        - ``parameters`` is missing in mode='manual';
        - the requested number of draws is rejected.
    """
    draws = self._check_mode_and_draws_and_parameters(
        mode, draws, parameters, force_draws
    )

    # grid of log-scale multiplication factors centered on log_x0
    log_factors = np.linspace(-2, 2, 100) + log_x0
    fig, ax = plt.subplots(1, 1, sharey=True, figsize=(4, 3))

    for i in tqdm(range(draws)):
        if mode == "draws":
            theta = self._posterior_sample(i)
        elif mode == "mean":
            theta = self._posterior_mean()
        elif mode == "manual":
            theta = parameters
        else:
            # defensive only: the mode has already been validated above
            raise NotImplementedError(
                f"Bad mode: {mode}. Mode must be one of 'draws', 'mean' or 'manual'"
            )

        loss = list(map(partial(self._loss, theta=theta), log_factors))

        ax.plot(
            np.exp(log_factors), loss,
            color="black",
        )

    # empty artist that only carries the single legend entry for all curves.
    # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
    ax.plot(
        [], [], color="black",
        label=rf"$\ell = S(t={self.time},x_{{in}}=C_{{ext}} \cdot \phi) - {self.x}$"
    )
    ax.set_ylabel(r"Loss ($\ell$)")
    ax.set_xlabel(r"Multiplication factor ($\phi$)")
    ax.set_title(f"ID: {self.sim.coordinates['id'][0]}")
    ax.set_ylim(0, ax.get_ylim()[1] * 1.25)
    ax.legend(frameon=False)
    fig.tight_layout()

    self.figure_loss_curve = fig
|
|
205
|
+
|
|
206
|
+
def _check_mode_and_draws_and_parameters(self, mode, draws, parameters, force_draws):
    """Validate the estimation mode and resolve the number of iterations.

    Parameters
    ----------
    mode : str
        One of 'draws', 'mean' or 'manual'.
    draws : Optional[int]
        Requested number of posterior draws. Only used in mode='draws'. If None,
        all samples of the posterior (chain x draw) are used.
    parameters : Optional[Dict]
        Manually specified model parameters. Required in mode='manual'. In the
        other modes they are ignored, and a warning is emitted if given.
    force_draws : bool
        Accept fewer than 100 draws in mode='draws'.

    Returns
    -------
    int
        Number of parameter sets to iterate over (always 1 for 'mean' and
        'manual').

    Raises
    ------
    GutsBaseError
        Raised in any of these cases:
        - the mode is unknown;
        - a posterior is required ('draws', 'mean') but missing;
        - ``parameters`` is None in mode='manual';
        - ``draws`` is not positive;
        - ``draws`` < 100 without ``force_draws``.
    """
    if mode == "draws":
        self._assert_posterior()

        if draws is None:
            posterior = self.sim.inferer.idata.posterior
            draws = posterior.sizes["chain"] * posterior.sizes["draw"]
        elif draws < 1:
            raise GutsBaseError(
                f"draws must be a positive integer, got draws={draws}."
            )
        elif draws < 100 and not force_draws:
            raise GutsBaseError(
                "draws must be at least 100. Preferably > 1000. " +
                f"If you don't want uncertainty assessment of the {self._name} " +
                "estimates, use mode='mean'. If you really want to use less than " +
                "100 draws, use force_draws=True at your own risk."
            )

    elif mode == "mean":
        self._assert_posterior()
        draws = 1

    elif mode == "manual":
        draws = 1
        if parameters is None:
            raise GutsBaseError(self._parameter_msg)

    else:
        raise GutsBaseError(
            f"Bad mode: {mode}. Mode must be one of 'draws', 'mean' or 'manual'"
        )

    # only warn when the user actually passed parameters that will be ignored
    if mode != "manual" and parameters is not None:
        warnings.warn(
            f"Values passed to 'parameters' don't have an effect in mode='{mode}'",
            stacklevel=3,  # point at the caller of estimate()/plot_loss_curve()
        )

    return draws
|
|
249
|
+
|
|
250
|
+
|
|
117
251
|
def estimate(
|
|
118
252
|
self,
|
|
119
253
|
mode: Literal["draws", "mean", "manual"] = "draws",
|
|
@@ -124,6 +258,7 @@ class ECxEstimator:
|
|
|
124
258
|
optimizer_tol: float = 1e-5,
|
|
125
259
|
method: str = "cobyla",
|
|
126
260
|
show_plot: bool = True,
|
|
261
|
+
force_draws: bool = False,
|
|
127
262
|
**optimizer_kwargs
|
|
128
263
|
):
|
|
129
264
|
"""The minimizer for the EC_x operates on the unbounded linear scale, estimating
|
|
@@ -177,48 +312,15 @@ class ECxEstimator:
|
|
|
177
312
|
show_plot : bool
|
|
178
313
|
Show the results plot of the lpx. Default: True
|
|
179
314
|
|
|
315
|
+
force_draws : bool
|
|
316
|
+
Force the estimate method to accept a number of draws less than 100. Default: False
|
|
317
|
+
|
|
180
318
|
optimizer_kwargs :
|
|
181
319
|
Additional arguments to pass to the optimizer
|
|
182
320
|
|
|
183
321
|
"""
|
|
184
322
|
x0_tries = np.array([0.0, -1.0, 1.0, -2.0, 2.0]) + log_x0
|
|
185
|
-
|
|
186
|
-
if mode == "draws":
|
|
187
|
-
if draws is None:
|
|
188
|
-
draws = (
|
|
189
|
-
self.sim.inferer.idata.posterior.sizes["chain"] *
|
|
190
|
-
self.sim.inferer.idata.posterior.sizes["draw"]
|
|
191
|
-
)
|
|
192
|
-
elif draws < 100:
|
|
193
|
-
raise ValueError(
|
|
194
|
-
"draws must be larger than 100. Preferably > 1000. "
|
|
195
|
-
f"If you don't want uncertainty assessment of the {self._name} "
|
|
196
|
-
"estimates, use mode='mean'"
|
|
197
|
-
)
|
|
198
|
-
else:
|
|
199
|
-
pass
|
|
200
|
-
|
|
201
|
-
warnings.warn(
|
|
202
|
-
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
203
|
-
)
|
|
204
|
-
|
|
205
|
-
elif mode == "mean":
|
|
206
|
-
draws = 1
|
|
207
|
-
|
|
208
|
-
warnings.warn(
|
|
209
|
-
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
210
|
-
)
|
|
211
|
-
|
|
212
|
-
elif mode == "manual":
|
|
213
|
-
draws = 1
|
|
214
|
-
if parameters is None:
|
|
215
|
-
raise ValueError(
|
|
216
|
-
"parameters need to be provided if mode='manual'"
|
|
217
|
-
)
|
|
218
|
-
else:
|
|
219
|
-
raise NotImplementedError(
|
|
220
|
-
f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
|
|
221
|
-
)
|
|
323
|
+
draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)
|
|
222
324
|
|
|
223
325
|
self._mode = mode
|
|
224
326
|
mult_factor = []
|
|
@@ -310,7 +412,7 @@ class ECxEstimator:
|
|
|
310
412
|
self.sim.dispatch_constructor()
|
|
311
413
|
|
|
312
414
|
if self._mode is None:
|
|
313
|
-
raise
|
|
415
|
+
raise GutsBaseError(
|
|
314
416
|
"Run .estimate() before plot_profile_and_effect()"
|
|
315
417
|
)
|
|
316
418
|
elif self._mode == "mean" or self._mode == "draws":
|
|
@@ -375,6 +477,8 @@ class ECxEstimator:
|
|
|
375
477
|
ax2.set_ylim(0, None)
|
|
376
478
|
fig.tight_layout()
|
|
377
479
|
|
|
480
|
+
self.figure_profile_and_effect = fig
|
|
481
|
+
|
|
378
482
|
self.sim.coordinates["time"] = coordinates_backup
|
|
379
483
|
self.sim.dispatch_constructor()
|
|
380
484
|
|
|
@@ -396,4 +500,10 @@ class LPxEstimator(ECxEstimator):
|
|
|
396
500
|
):
|
|
397
501
|
x_in = sim.model_parameters["x_in"].sel(id=[id])
|
|
398
502
|
time = sim.coordinates["time"][-1]
|
|
399
|
-
super().__init__(
|
|
503
|
+
super().__init__(
|
|
504
|
+
sim=sim,
|
|
505
|
+
effect="survival",
|
|
506
|
+
x=x,
|
|
507
|
+
time=time,
|
|
508
|
+
x_in=x_in
|
|
509
|
+
)
|
guts_base/sim/mempy.py
CHANGED
|
@@ -1,44 +1,64 @@
|
|
|
1
1
|
import pathlib
|
|
2
|
-
from typing import Dict, Optional, Literal
|
|
2
|
+
from typing import Dict, Optional, Literal, Protocol, TypedDict, List
|
|
3
3
|
import re
|
|
4
|
-
|
|
4
|
+
import os
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import xarray as xr
|
|
8
|
-
from pymob import SimulationBase
|
|
9
8
|
from pymob.sim.config import Config, DataVariable, Datastructure
|
|
10
9
|
from pymob.sim.parameters import Param
|
|
11
10
|
from guts_base.sim import GutsBase
|
|
12
|
-
from mempy.model import (
|
|
13
|
-
Model,
|
|
14
|
-
RED_IT,
|
|
15
|
-
RED_SD,
|
|
16
|
-
RED_IT_DA,
|
|
17
|
-
RED_SD_DA,
|
|
18
|
-
RED_IT_IA,
|
|
19
|
-
RED_SD_IA,
|
|
20
|
-
BufferGUTS_IT,
|
|
21
|
-
BufferGUTS_IT_CA,
|
|
22
|
-
BufferGUTS_IT_DA
|
|
23
|
-
)
|
|
24
11
|
|
|
25
12
|
__all__ = [
|
|
26
13
|
"PymobSimulator",
|
|
27
14
|
]
|
|
28
15
|
|
|
16
|
+
class _ParamsInfoRequired(TypedDict):
    """Keys every parameter-metadata entry must provide."""
    name: str
    # initial value; read unconditionally as the 'initial' column in _set_params
    initial: float
    # whether the parameter is free; read unconditionally as the 'vary' column
    vary: bool


class ParamsInfoDict(_ParamsInfoRequired, total=False):
    """Metadata describing one model parameter, as stored in Model.params_info.

    Only name, initial and vary are required. _set_params itself creates an
    'eps' entry with just these three keys, and it maps missing (masked)
    min/max values to None, so the remaining keys are optional.
    """
    # lower bound (optional)
    min: float
    # upper bound (optional)
    max: float
    # NOTE(review): presumably a prior specification string — consumer not
    # visible in this section, confirm
    prior: str
|
|
23
|
+
|
|
24
|
+
class _StateVariablesRequired(TypedDict):
    """Keys every state-variable entry must provide."""
    dimensions: List[str]
    observed: bool


class StateVariablesDict(_StateVariablesRequired, total=False):
    """Metadata of one state variable, as stored in Model.state_variables.

    'y0' is optional: from_model_and_dataset only builds initial values for
    entries that contain it (``if "y0" in v``).
    """
    y0: List[float]
|
|
28
|
+
|
|
29
|
+
class Model(Protocol):
    """Structural interface of a model that can be wrapped by PymobSimulator.

    Any object that exposes these attributes and static methods satisfies
    the protocol; no inheritance is required. The methods are only declared
    here and raise NotImplementedError if called on the protocol itself.
    """
    # NOTE(review): name of an optional extra data dimension — its consumer is
    # not visible in this section, confirm semantics
    extra_dim: Optional[str]
    # parameter metadata keyed by parameter name; read by _set_params
    params_info: Dict[str, ParamsInfoDict]
    # state-variable metadata keyed by variable name; entries with a 'y0' key
    # are turned into cfg.simulation.y0
    state_variables: Dict[str, StateVariablesDict]
    # NOTE(review): default parameter metadata — consumer not visible here
    _params_info_defaults: Dict[str, ParamsInfoDict]
    # if True, _set_params adds a fixed 'eps' parameter (replaces the former
    # isinstance checks against the RED_IT* / BufferGUTS_IT* classes)
    _it_model: bool

    # Implementations presumably take model-specific arguments (confirm against
    # the model classes). Declaring *args/**kwargs keeps the protocol
    # compatible with any such signature; a zero-argument declaration could
    # not be satisfied by an implementation that requires arguments.
    @staticmethod
    def _rhs_jax(*args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def _solver_post_processing(*args, **kwargs):
        raise NotImplementedError

    # its __module__ and __name__ are used to fill
    # cfg.inference_numpyro.user_defined_error_model
    @staticmethod
    def _likelihood_func_jax(*args, **kwargs):
        raise NotImplementedError
|
|
47
|
+
|
|
48
|
+
|
|
29
49
|
class PymobSimulator(GutsBase):
|
|
30
50
|
|
|
31
51
|
@classmethod
|
|
32
|
-
def
|
|
52
|
+
def from_model_and_dataset(
|
|
33
53
|
cls,
|
|
34
|
-
exposure_data: Dict,
|
|
35
|
-
survival_data: Dict,
|
|
36
54
|
model: Model,
|
|
55
|
+
exposure_data: Dict[str, pd.DataFrame],
|
|
56
|
+
survival_data: pd.DataFrame,
|
|
37
57
|
info_dict: Dict = {},
|
|
38
58
|
pymob_config: Optional[Config] = None,
|
|
39
59
|
output_directory: str|pathlib.Path = pathlib.Path("output/pymob"),
|
|
40
60
|
default_prior: Literal["uniform", "lognorm"] = "lognorm",
|
|
41
|
-
) ->
|
|
61
|
+
) -> "PymobSimulator":
|
|
42
62
|
"""Construct a PymobSimulator from the
|
|
43
63
|
"""
|
|
44
64
|
|
|
@@ -46,7 +66,7 @@ class PymobSimulator(GutsBase):
|
|
|
46
66
|
cfg = Config()
|
|
47
67
|
# Configure: The configuration can be overridden in a subclass to override the
|
|
48
68
|
# configuration
|
|
49
|
-
cls.
|
|
69
|
+
cls._configure(config=cfg)
|
|
50
70
|
else:
|
|
51
71
|
cfg = pymob_config
|
|
52
72
|
|
|
@@ -55,6 +75,14 @@ class PymobSimulator(GutsBase):
|
|
|
55
75
|
|
|
56
76
|
cfg.case_study.output = str(output_directory)
|
|
57
77
|
|
|
78
|
+
# overrides scenario path. This means the scenario is also expected in the
|
|
79
|
+
# same folder
|
|
80
|
+
cfg.case_study.scenario_path_override = str(output_directory)
|
|
81
|
+
cfg.case_study.scenario = output_directory.stem
|
|
82
|
+
cfg.case_study.data = cfg.case_study.output_path
|
|
83
|
+
cfg.case_study.observations = "observations.nc"
|
|
84
|
+
cfg.create_directory(directory="results", force=True)
|
|
85
|
+
|
|
58
86
|
# parse observations
|
|
59
87
|
# obs can be simply subset by selection obs.sel(substance="Exposure-Dime")
|
|
60
88
|
observations = xr.combine_by_coords([
|
|
@@ -62,9 +90,19 @@ class PymobSimulator(GutsBase):
|
|
|
62
90
|
cls._survival_data_to_xarray(survival_data)
|
|
63
91
|
])
|
|
64
92
|
|
|
93
|
+
observations.to_netcdf(
|
|
94
|
+
os.path.join(cfg.case_study.output_path, cfg.case_study.observations)
|
|
95
|
+
)
|
|
96
|
+
|
|
65
97
|
# configure model and likelihood function
|
|
66
|
-
|
|
67
|
-
cfg.
|
|
98
|
+
# extract the fully qualified name of the model module.name
|
|
99
|
+
cfg.simulation.model_class = "{module}.{name}".format(
|
|
100
|
+
module=model.__module__, name=type(model).__name__
|
|
101
|
+
)
|
|
102
|
+
cfg.inference_numpyro.user_defined_error_model = "{module}.{name}".format(
|
|
103
|
+
module=model._likelihood_func_jax.__module__,
|
|
104
|
+
name=model._likelihood_func_jax.__name__
|
|
105
|
+
)
|
|
68
106
|
|
|
69
107
|
# derive data structure and params from the model instance
|
|
70
108
|
cls._set_data_structure(config=cfg, model=model)
|
|
@@ -75,35 +113,37 @@ class PymobSimulator(GutsBase):
|
|
|
75
113
|
cfg.simulation.y0 = [f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v]
|
|
76
114
|
|
|
77
115
|
# create a simulation object
|
|
116
|
+
# It is essential that all post processing tasks are done in self.setup()
|
|
117
|
+
# which is extended below. This ensures that the simulation can also be run
|
|
118
|
+
# from automated tools like pymob-infer
|
|
78
119
|
sim = cls(config=cfg)
|
|
79
|
-
sim.
|
|
80
|
-
|
|
81
|
-
# initialize
|
|
82
|
-
sim.load_modules()
|
|
83
|
-
sim.set_logger()
|
|
84
|
-
|
|
85
|
-
sim.initialize(input={"observations": observations, "model": model})
|
|
86
|
-
|
|
87
|
-
sim.validate()
|
|
88
|
-
sim.dispatch_constructor()
|
|
89
|
-
|
|
90
|
-
|
|
120
|
+
sim.setup()
|
|
91
121
|
return sim
|
|
92
122
|
|
|
93
|
-
def
|
|
94
|
-
|
|
95
|
-
|
|
123
|
+
def reset_observations(self):
    """Restore the observations captured at the end of setup().

    ``self.observations`` is replaced with a deep copy of the snapshot held in
    ``self._obs_backup``. Because the snapshot is copied rather than handed
    out, later in-place modifications of the restored observations cannot
    leak back into it. The method can therefore be called repeatedly.

    NOTE(review): the previous docstring stated that this also resets
    ``sim.coordinates``. That relies on the observations setter of the parent
    class (not visible here) — confirm.
    """
    self.observations = self._obs_backup.copy(deep=True)
|
|
96
129
|
|
|
97
|
-
|
|
130
|
+
def setup(self, **evaluator_kwargs):
    """Set up the simulation, then snapshot its observations.

    All keyword arguments are forwarded unchanged to the parent ``setup``.
    Afterwards a deep copy of ``self.observations`` is kept in
    ``self._obs_backup``, which ``reset_observations`` restores.
    """
    super().setup(**evaluator_kwargs)
    snapshot = self.observations.copy(deep=True)
    self._obs_backup = snapshot
|
|
98
133
|
|
|
99
134
|
|
|
100
135
|
@classmethod
|
|
101
|
-
def
|
|
136
|
+
def _configure(cls, config: Config):
|
|
102
137
|
"""This is normally set in the configuration file passed to a SimulationBase class.
|
|
103
138
|
Since the mempy to pymob converter initializes pymob.SimulationBase from scratch
|
|
104
139
|
(without using a config file), the necessary settings have to be specified here.
|
|
105
140
|
"""
|
|
106
141
|
config.case_study.output = "results"
|
|
142
|
+
config.case_study.simulation = "PymobSimulator"
|
|
143
|
+
|
|
144
|
+
# this must be named guts_base, whihc is the name of the pip package and
|
|
145
|
+
# this regulates which packages are loaded.
|
|
146
|
+
config.case_study.name = "guts_base"
|
|
107
147
|
|
|
108
148
|
config.simulation.x_dimension = "time"
|
|
109
149
|
config.simulation.batch_dimension = "id"
|
|
@@ -132,33 +172,6 @@ class PymobSimulator(GutsBase):
|
|
|
132
172
|
config.inference_numpyro.svi_iterations = 10_000
|
|
133
173
|
config.inference_numpyro.svi_learning_rate = 0.001
|
|
134
174
|
|
|
135
|
-
@staticmethod
|
|
136
|
-
def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
|
|
137
|
-
"""
|
|
138
|
-
TODO: Currently no rect interpolation
|
|
139
|
-
"""
|
|
140
|
-
arrays = {}
|
|
141
|
-
for key, df in exposure_data.items():
|
|
142
|
-
# this override is necessary to make all dimensions work out
|
|
143
|
-
df.index.name = "time"
|
|
144
|
-
arrays.update({
|
|
145
|
-
key: df.to_xarray().to_dataarray(dim="id", name=key)
|
|
146
|
-
})
|
|
147
|
-
|
|
148
|
-
exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
|
|
149
|
-
exposure_array = exposure_array.transpose("id", "time", ...)
|
|
150
|
-
return xr.Dataset({"exposure": exposure_array})
|
|
151
|
-
|
|
152
|
-
@staticmethod
|
|
153
|
-
def _survival_data_to_xarray(survival_data: pd.DataFrame):
|
|
154
|
-
# TODO: survival name is currently not kept because the raw data is not transferred from the survival
|
|
155
|
-
survival_data.index.name = "time"
|
|
156
|
-
|
|
157
|
-
survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
|
|
158
|
-
survival_array = survival_array.transpose("id", "time", ...)
|
|
159
|
-
arrays = {"survival": survival_array}
|
|
160
|
-
return xr.Dataset(arrays)
|
|
161
|
-
|
|
162
175
|
@classmethod
|
|
163
176
|
def _set_data_structure(cls, config: Config, model: Model):
|
|
164
177
|
"""Takes a dictionary that is specified in the model and uses only keys that
|
|
@@ -179,10 +192,7 @@ class PymobSimulator(GutsBase):
|
|
|
179
192
|
def _set_params(cls, config: Config, model: Model, default_prior: str):
|
|
180
193
|
params_info = model.params_info
|
|
181
194
|
|
|
182
|
-
if
|
|
183
|
-
RED_IT, RED_IT_DA, RED_IT_IA,
|
|
184
|
-
BufferGUTS_IT, BufferGUTS_IT_CA, BufferGUTS_IT_DA
|
|
185
|
-
)):
|
|
195
|
+
if model._it_model:
|
|
186
196
|
eps = config.jaxsolver.atol * 10
|
|
187
197
|
params_info["eps"] = {'name':'eps', 'initial':eps, 'vary':False}
|
|
188
198
|
|
|
@@ -210,6 +220,11 @@ class PymobSimulator(GutsBase):
|
|
|
210
220
|
_init = group["initial"].values.astype(float)
|
|
211
221
|
_free = group["vary"].values
|
|
212
222
|
|
|
223
|
+
if isinstance(_min, np.ma.core.MaskedConstant):
|
|
224
|
+
_min = None
|
|
225
|
+
if isinstance(_max, np.ma.core.MaskedConstant):
|
|
226
|
+
_max = None
|
|
227
|
+
|
|
213
228
|
# TODO: allow for parsing one N-D prior from multiple priors
|
|
214
229
|
# TODO: Another choice would be to parse vary=False priors as deterministic
|
|
215
230
|
# and use a composite prior from a deterministic and a free prior as
|
guts_base/sim/report.py
CHANGED
guts_base/sim/utils.py
ADDED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: guts_base
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Basic GUTS model implementation in pymob
|
|
5
5
|
Author-email: Florian Schunck <fluncki@protonmail.com>
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -692,8 +692,7 @@ License-File: LICENSE
|
|
|
692
692
|
Requires-Dist: openpyxl>=3.1.3
|
|
693
693
|
Requires-Dist: Bottleneck>=1.5.0
|
|
694
694
|
Requires-Dist: expydb>=0.6.0
|
|
695
|
-
Requires-Dist:
|
|
696
|
-
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.4.1
|
|
695
|
+
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
|
|
697
696
|
Provides-Extra: dev
|
|
698
697
|
Requires-Dist: pytest>=7.3; extra == "dev"
|
|
699
698
|
Requires-Dist: bumpver; extra == "dev"
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
guts_base/__init__.py,sha256=ulUbdy3eks46Fl2VAVRRymWiAmzDOj_pqRCBpItSnMI,227
|
|
2
|
+
guts_base/mod.py,sha256=AzOCg1A8FP5EtVfp-66HT7G7h_wnHkenieaxTc9qCyk,5796
|
|
3
|
+
guts_base/plot.py,sha256=rjIzyPUikOtubjWI5mX_ZY8ARWl2NiHDmwAZAdoaPS8,5693
|
|
4
|
+
guts_base/prob.py,sha256=ITwo5dAGMHr5xTldilHMbKU6AFsWo4_ZwbfaXh97Gew,5443
|
|
5
|
+
guts_base/data/__init__.py,sha256=JBgft1DTledwvB5hRZnyGiKWv-RXo1OIpb5kJXloOmo,826
|
|
6
|
+
guts_base/data/expydb.py,sha256=Kcc6CeZMl3oEelk5UBN9VEfwgNF3CzTh13ooVkufAjE,8218
|
|
7
|
+
guts_base/data/generator.py,sha256=rGOZU3B0Ho8V6KtfjcAmely8lnlqNFV8cRyGboayTRc,2910
|
|
8
|
+
guts_base/data/openguts.py,sha256=WvhYl_AOdvNgzrcVS2f_PYbXNH_wSAz2uIBSR6BMSh0,11078
|
|
9
|
+
guts_base/data/preprocessing.py,sha256=qggYkx2x62ingU1BNhJFyL1eQdFQsDJR2lefVfVWW2U,1732
|
|
10
|
+
guts_base/data/survival.py,sha256=U-Ehloo8vnD81VeIglXLEUHX9lt7SjtEs2YEB0D9FHE,5096
|
|
11
|
+
guts_base/data/time_of_death.py,sha256=hwngUwfRP3u8WmD3dHyXrphuu5d8ZJTKyBovGRwAHNQ,21014
|
|
12
|
+
guts_base/data/utils.py,sha256=u3gGDJK15MfRUP4iIxsS-I1oqxD2qH_ugsT7o_Eac18,236
|
|
13
|
+
guts_base/sim/__init__.py,sha256=5VgzsOXjMsylfWy4neeovZBG1w6G6p_tFaIfhnIeLPM,415
|
|
14
|
+
guts_base/sim/base.py,sha256=7u7gj6HQSCtFUXrrX-ngfWsWnFEMk5NrDVJH0CIolAs,28960
|
|
15
|
+
guts_base/sim/constructors.py,sha256=Kz9FHIH3EHsSIKd9sQgHa3eveniFifFlk1Hf-QR69Pg,875
|
|
16
|
+
guts_base/sim/ecx.py,sha256=F2ImQbtJjP57udebY2c_vtOf4gB3yKSje0P8s9VIbtQ,18858
|
|
17
|
+
guts_base/sim/mempy.py,sha256=IHd87UrmdXpC7y7q0IjYQJH075frjbp2a-dMVBeqZ0U,10164
|
|
18
|
+
guts_base/sim/report.py,sha256=swTAp1qNTjDwWh-YOq_TLqtSxjQvjdwaNVHpEW8kXK4,1959
|
|
19
|
+
guts_base/sim/utils.py,sha256=Qj_FPH6kywVxOwgCerS7w5CyuYR9HKmvBWFpmxwDFgk,256
|
|
20
|
+
guts_base-1.0.0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
21
|
+
guts_base-1.0.0.dist-info/METADATA,sha256=g8sJM482NTXEGv7dJKbLVujO_sRzkpg3fuHadmDMdhc,45406
|
|
22
|
+
guts_base-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
guts_base-1.0.0.dist-info/entry_points.txt,sha256=icsHzG2jQ90ZS7XvLsI5Qj0-qGuPv2T0RBVN5daGCPU,183
|
|
24
|
+
guts_base-1.0.0.dist-info/top_level.txt,sha256=PxhBgUd4r39W_VI4FyJjARwKbV5_glgCVnd6v_zAGdE,10
|
|
25
|
+
guts_base-1.0.0.dist-info/RECORD,,
|
guts_base/sim.py
DELETED
|
File without changes
|
guts_base-0.8.6.dist-info/RECORD
DELETED
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
guts_base/__init__.py,sha256=dYv5zH5jYPdPZXIzku6YPkvB47XD8qaTlY-0olh-4cA,208
|
|
2
|
-
guts_base/mod.py,sha256=aLb09E-6CwgKgq0GbwIT_Alv167cVFA5zAxwx80F1aE,7454
|
|
3
|
-
guts_base/plot.py,sha256=rjIzyPUikOtubjWI5mX_ZY8ARWl2NiHDmwAZAdoaPS8,5693
|
|
4
|
-
guts_base/prob.py,sha256=k1b-gc8yZNg8DMnxJHJzCy8Pud8_qEygGkrnmS2AcMY,14004
|
|
5
|
-
guts_base/sim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
-
guts_base/data/__init__.py,sha256=U5sJJgPwEVHVGTLeZNwl8VKLIgn1twpDNkOCMRSV1Nk,808
|
|
7
|
-
guts_base/data/expydb.py,sha256=Kcc6CeZMl3oEelk5UBN9VEfwgNF3CzTh13ooVkufAjE,8218
|
|
8
|
-
guts_base/data/generator.py,sha256=fos7y-xxTlBVpYsHD_2kdYSH40KILPUQ_Jy0mwjO5jc,2897
|
|
9
|
-
guts_base/data/openguts.py,sha256=WvhYl_AOdvNgzrcVS2f_PYbXNH_wSAz2uIBSR6BMSh0,11078
|
|
10
|
-
guts_base/data/preprocessing.py,sha256=qggYkx2x62ingU1BNhJFyL1eQdFQsDJR2lefVfVWW2U,1732
|
|
11
|
-
guts_base/data/survival.py,sha256=yNNZ7MabzrMeqPHOa8-LPxCgsIN4_X0aj_ANmeC9Wd8,4878
|
|
12
|
-
guts_base/data/time_of_death.py,sha256=hwngUwfRP3u8WmD3dHyXrphuu5d8ZJTKyBovGRwAHNQ,21014
|
|
13
|
-
guts_base/data/utils.py,sha256=u3gGDJK15MfRUP4iIxsS-I1oqxD2qH_ugsT7o_Eac18,236
|
|
14
|
-
guts_base/sim/__init__.py,sha256=vgZu2oi4RHIQuLS19570xBdLQasL75puxabGrbLuIGA,276
|
|
15
|
-
guts_base/sim/base.py,sha256=lxfHQ57pXKz1ZGnFqOIoeUplzvFWM1OjDX0xYQpSgaw,20921
|
|
16
|
-
guts_base/sim/ecx.py,sha256=IvZQnmMs3FLX6unuj7mPv2AaZ7l-S4TJmYmPI4fXE1k,14573
|
|
17
|
-
guts_base/sim/mempy.py,sha256=xsA3Q7jQrdcmQZJzKBdQnr8yZk-f36YgbLlV-VvJpxg,9514
|
|
18
|
-
guts_base/sim/report.py,sha256=D_6eoSvjdcsDtB0PlmNkzC1LRk_7WXzb_T6D6AZjQNE,1985
|
|
19
|
-
guts_base-0.8.6.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
20
|
-
guts_base-0.8.6.dist-info/METADATA,sha256=n0OWLZd2CZWg7j7M2sqcXfFuiOSPqSG60KvQ9IuoqHM,45437
|
|
21
|
-
guts_base-0.8.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
22
|
-
guts_base-0.8.6.dist-info/entry_points.txt,sha256=icsHzG2jQ90ZS7XvLsI5Qj0-qGuPv2T0RBVN5daGCPU,183
|
|
23
|
-
guts_base-0.8.6.dist-info/top_level.txt,sha256=PxhBgUd4r39W_VI4FyJjARwKbV5_glgCVnd6v_zAGdE,10
|
|
24
|
-
guts_base-0.8.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|