guts_base-2.0.0b0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guts_base/__init__.py +15 -0
- guts_base/data/__init__.py +35 -0
- guts_base/data/expydb.py +248 -0
- guts_base/data/generator.py +191 -0
- guts_base/data/openguts.py +296 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +148 -0
- guts_base/data/time_of_death.py +595 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +332 -0
- guts_base/plot.py +201 -0
- guts_base/prob/__init__.py +13 -0
- guts_base/prob/binom.py +18 -0
- guts_base/prob/conditional_binom.py +118 -0
- guts_base/prob/conditional_binom_mv.py +233 -0
- guts_base/prob/predictions.py +164 -0
- guts_base/sim/__init__.py +28 -0
- guts_base/sim/base.py +1286 -0
- guts_base/sim/config.py +170 -0
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +585 -0
- guts_base/sim/mempy.py +290 -0
- guts_base/sim/report.py +405 -0
- guts_base/sim/transformer.py +548 -0
- guts_base/sim/units.py +313 -0
- guts_base/sim/utils.py +10 -0
- guts_base-2.0.0b0.dist-info/METADATA +853 -0
- guts_base-2.0.0b0.dist-info/RECORD +32 -0
- guts_base-2.0.0b0.dist-info/WHEEL +5 -0
- guts_base-2.0.0b0.dist-info/entry_points.txt +3 -0
- guts_base-2.0.0b0.dist-info/licenses/LICENSE +674 -0
- guts_base-2.0.0b0.dist-info/top_level.txt +1 -0
guts_base/sim/mempy.py
ADDED
@@ -0,0 +1,290 @@

import pathlib
from typing import Dict, Optional, Literal, Protocol, TypedDict, List, Tuple
import re
import os
import numpy as np
from numpy.ma.core import MaskedConstant
import pandas as pd
import xarray as xr
from pymob.sim.config import Config, DataVariable, Datastructure
from pymob.sim.config import Param
from guts_base.sim import GutsBase

__all__ = [
    "PymobSimulator",
]

class ParamsInfoDict(TypedDict):
    name: str
    min: float
    max: float
    initial: float
    vary: bool
    prior: str
    dims: Tuple[str]
    unit: str|List[str]
    module: str

class StateVariablesDict(TypedDict):
    dimensions: List[str]
    observed: bool
    y0: List[float]

class Model(Protocol):
    extra_dim: str
    params_info: Dict[str, ParamsInfoDict]
    state_variables: Dict[str, StateVariablesDict]
    _params_info_defaults: Dict[str, ParamsInfoDict]
    _it_model: bool

    @staticmethod
    def _rhs_jax():
        raise NotImplementedError

    @staticmethod
    def _solver_post_processing():
        raise NotImplementedError

    @staticmethod
    def _likelihood_func_jax():
        raise NotImplementedError


class PymobSimulator(GutsBase):

    @classmethod
    def from_model_and_dataset(
        cls,
        model: Model,
        exposure_data: Dict[str, pd.DataFrame],
        survival_data: pd.DataFrame,
        info_dict: Dict = {},
        pymob_config: Optional[Config] = None,
        output_directory: str|pathlib.Path = pathlib.Path("output/pymob"),
        default_prior: Literal["uniform", "lognorm"] = "lognorm",
    ) -> "PymobSimulator":
        """Construct a PymobSimulator from a mempy model instance and the
        corresponding exposure and survival data.
        """

        if pymob_config is None:
            cfg = Config()
            # The default configuration can be overridden in a subclass
            cls._configure(config=cfg)
        else:
            cfg = pymob_config

        if isinstance(output_directory, str):
            output_directory = pathlib.Path(output_directory)

        cfg.case_study.output = str(output_directory)

        # overrides scenario path. This means the scenario is also expected in the
        # same folder
        cfg.case_study.scenario_path_override = str(output_directory)
        cfg.case_study.scenario = output_directory.stem
        cfg.case_study.data = cfg.case_study.output_path
        cfg.case_study.observations = "observations.nc"
        cfg.create_directory(directory="results", force=True)

        obs = cls._observations_from_dataframes(
            exposure_data=exposure_data,
            survival_data=survival_data,
            exposure_dim=model.extra_dim,
            unit_input=cfg.guts_base.unit_input,
            unit_time=cfg.guts_base.unit_time,
        )

        obs.to_netcdf(os.path.join(cfg.case_study.output_path, cfg.case_study.observations))

        # configure model and likelihood function
        # extract the fully qualified name of the model module.name
        if isinstance(model, type):
            raise TypeError(
                f"model '{model.__name__}' must be initialized. Initialize with "+
                f"{model.__name__}(). Don't forget to specify the number of exposures "+
                "with e.g. RED_SD_DA(num_expos=2) if your model has two exposures."
            )
        else:
            pass

        cfg.simulation.model_class = "{module}.{name}".format(
            module=model.__module__, name=type(model).__name__
        )
        cfg.inference_numpyro.user_defined_error_model = "{module}.{name}".format(
            module=model._likelihood_func_jax.__module__,
            name=model._likelihood_func_jax.__name__
        )

        # derive data structure and params from the model instance
        cls._set_data_structure(config=cfg, model=model)
        cls._set_params(config=cfg, model=model, default_prior=default_prior)

        # configure starting values and input
        cfg.simulation.x_in = ["exposure=exposure"]
        cfg.simulation.y0 = [f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v]

        cfg.guts_base.background_mortality_parameters = cls._get_background_mortality_params(model)

        # create a simulation object
        # It is essential that all post processing tasks are done in self.setup()
        # which is extended below. This ensures that the simulation can also be run
        # from automated tools like pymob-infer
        sim = cls(config=cfg)
        sim.setup()
        return sim


    def setup(self, **evaluator_kwargs):
        super().setup(**evaluator_kwargs)
        self._obs_backup = self.observations.copy(deep=True)


    @classmethod
    def _configure(cls, config: Config):
        """This is normally set in the configuration file passed to a SimulationBase class.
        Since the mempy to pymob converter initializes pymob.SimulationBase from scratch
        (without using a config file), the necessary settings have to be specified here.
        """
        config.case_study.output = "results"
        config.case_study.simulation = "PymobSimulator"

        # this must be named guts_base, which is the name of the pip package and
        # regulates which packages are loaded.
        config.case_study.name = "guts_base"

        config.simulation.x_dimension = "time"
        config.simulation.batch_dimension = "id"
        config.simulation.solver_post_processing = None

        # this is the registered guts-base section
        # No longer necessary, because these are saved as defaults
        # config.simulation.unit_time = "day"
        # config.simulation.n_reindexed_x = 100
        # config.simulation.forward_interpolate_exposure_data = True

        config.inference.extra_vars = ["eps", "survivors_before_t", "survivors_at_start"]
        config.inference.n_predictions = 100

        config.jaxsolver.diffrax_solver = "Tsit5"
        config.jaxsolver.rtol = 1e-10
        config.jaxsolver.atol = 1e-12
        config.jaxsolver.throw_exception = True
        config.jaxsolver.pcoeff = 0.3
        config.jaxsolver.icoeff = 0.3
        config.jaxsolver.dcoeff = 0.0
        config.jaxsolver.max_steps = 1000000
        config.jaxsolver.throw_exception = True

        config.inference_numpyro.gaussian_base_distribution = True
        config.inference_numpyro.kernel = "svi"
        config.inference_numpyro.init_strategy = "init_to_median"
        config.inference_numpyro.svi_iterations = 10_000
        config.inference_numpyro.svi_learning_rate = 0.001

    @classmethod
    def _set_data_structure(cls, config: Config, model: Model):
        """Takes a dictionary that is specified in the model and uses only keys that
        are fields of the DataVariable config-model"""

        state_dict = model.state_variables

        config.data_structure = Datastructure.model_validate({
            key: DataVariable.model_validate({
                k: v for k, v in state_info.items()
                if k in DataVariable.model_fields
            })
            for key, state_info in state_dict.items()
        })

    @staticmethod
    def _get_background_mortality_params(model: Model):
        return [k for k, v in model.params_info.items() if v["module"] == "background-mortality"]


    @classmethod
    def _set_params(cls, config: Config, model: Model, default_prior: str):
        params_info = model.params_info

        if model._it_model:
            eps = config.jaxsolver.atol * 10
            params_info["eps"] = {'name': 'eps', 'initial': eps, 'vary': False}

        for par, param_dict in params_info.items():
            for k, v in model._params_info_defaults.items():
                if k not in param_dict:
                    param_dict.update({k: v})

        param_df = pd.DataFrame(params_info).T
        param_df["param_index"] = param_df.name.apply(lambda x: re.findall(r"\d+", x))
        param_df["param_index"] = param_df.param_index.apply(lambda x: int(x[0])-1 if len(x) == 1 else None)
        param_df["name"] = param_df.name.apply(lambda x: re.sub(r"\d+", "", x).strip("_"))

        for (param_name, ), group in param_df.groupby(["name"]):

            dims = list(dict.fromkeys(group["dims"]))
            dims = tuple([]) if dims == [None] else tuple(dims)

            prior = list(dict.fromkeys(group["prior"]))
            prior = prior[0] if len(prior) == 1 else prior

            _min = np.min(np.ma.masked_invalid(group["min"].values.astype(float)))
            _max = np.max(np.ma.masked_invalid(group["max"].values.astype(float)))
            _init = np.array(group["initial"].values.astype(float))
            _free = np.array(group["vary"].values)

            unit = list(dict.fromkeys(group["unit"]))
            unit = unit[0] if len(unit) == 1 else unit

            if isinstance(_min, MaskedConstant):
                _min = None
            if isinstance(_max, MaskedConstant):
                _max = None

            # TODO: allow for parsing one N-D prior from multiple priors
            # TODO: Another choice would be to parse vary=False priors as deterministic
            # and use a composite prior from a deterministic and a free prior as
            # the input into the model

            if prior is None:
                if _min is None or _max is None:
                    prior = None
                elif default_prior == "uniform":
                    _loc = _init * np.logical_not(_free) + _min * _free - config.jaxsolver.atol * 10 * np.logical_not(_free)
                    _scale = _init * np.logical_not(_free) + _max * _free + config.jaxsolver.atol * 10 * np.logical_not(_free)
                    _loc = _loc[0] if len(_loc) == 1 else _loc
                    _scale = _scale[0] if len(_scale) == 1 else _scale
                    prior = f"uniform(loc={_loc},scale={_scale})"
                elif default_prior == "lognorm":
                    _s = 3 * _free + config.jaxsolver.atol * 10 * np.logical_not(_free)
                    _init = _init[0] if len(_init) == 1 else _init
                    _s = _s[0] if len(_s) == 1 else _s

                    prior = f"lognorm(scale={_init},s={_s})"
                else:
                    raise ValueError(
                        f"Default prior: '{default_prior}' is not implemented. "+
                        "Use one of 'uniform', 'lognorm' or specify priors for each "+
                        "parameter directly with: "+
                        f"`model.params_dict['prior'] = {default_prior}(...)`"
                    )

            if prior is not None:
                prior = prior.replace(" ", ",")

            # if isinstance(value,float):
            param = Param.model_validate(dict(
                value=_init,
                free=np.max(_free),
                min=_min,
                max=_max,
                prior=prior,
                dims=dims,
                unit=unit,
            ))

            setattr(config.model_parameters, param_name, param)
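
For orientation, a minimal usage sketch of the constructor above; it is not part of the wheel. The model class RED_SD_DA is only named in the error message above, so its import path, the exact DataFrame layout expected by `_observations_from_dataframes`, and the substance and treatment names are assumptions.

import pandas as pd
from guts_base.sim.mempy import PymobSimulator
# from mempy.models import RED_SD_DA  # hypothetical import path for a GUTS-RED-SD model

# Assumed layout: one exposure DataFrame per substance and one survival DataFrame,
# indexed by time with one column per treatment.
exposure_data = {
    "substance_A": pd.DataFrame(
        {"T1": [1.0, 1.0, 0.0], "T2": [2.0, 2.0, 0.0]},
        index=pd.Index([0.0, 1.0, 2.0], name="time"),
    )
}
survival_data = pd.DataFrame(
    {"T1": [10, 9, 8], "T2": [10, 7, 5]},
    index=pd.Index([0.0, 1.0, 2.0], name="time"),
)

# The model must be passed as an instance (passing the class raises the TypeError above),
# e.g. with the number of exposures set explicitly:
# sim = PymobSimulator.from_model_and_dataset(
#     model=RED_SD_DA(num_expos=1),
#     exposure_data=exposure_data,
#     survival_data=survival_data,
#     output_directory="output/pymob/example",
#     default_prior="lognorm",
# )
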
guts_base/sim/report.py
ADDED
@@ -0,0 +1,405 @@

from functools import partial
import os
import itertools as it
from typing import List, Dict, Literal, Optional, Union, TYPE_CHECKING
import numpy as np
import pandas as pd
import xarray as xr
import arviz as az
from matplotlib import pyplot as plt
import pint

from pymob import SimulationBase
from pymob.sim.report import Report, reporting
from pymob.sim.config import Config, string_to_dict
from pymob.inference.analysis import round_to_sigfig, format_parameter

from guts_base.plot import plot_survival_multipanel, plot_exposure_multipanel
from guts_base.sim.ecx import ECxEstimator
from guts_base.sim.config import GutsBaseConfig
from guts_base.sim import units

if TYPE_CHECKING:
    from guts_base.sim import GutsBase


class GutsReport(Report):
    # ecx_estimates_times: List = [1, 2, 4, 10]
    # ecx_estimates_x: List = [0.1, 0.25, 0.5, 0.75, 0.9]
    # ecx_draws: int = 250
    # ecx_force_draws: bool = False
    # set_background_mortality_to_zero = True
    # table_parameter_stat_focus = "mean"
    # units = xr.Dataset({
    #     "metric": ["unit"],
    #     "k_d": ("metric", ["1/T"])
    # })
    # format_unit = "~P"

    def additional_reports(self, sim: "GutsBase"):
        super().additional_reports(sim=sim)
        self.model_input(sim)
        self.model_fits(sim)
        self.LCx_estimates(sim)

    @property
    def units(self):
        return units.derive_explicit_units(
            config=self.config,
            unit=self.observations["unit"]
        )


    @reporting
    def model_input(self, sim: "GutsBase"):
        self._write("### Exposure conditions")
        self._write(
            "These are the exposure conditions that were assumed for parameter inference. "+
            "Double check if they are aligned with your expectations. Especially short " +
            "exposure durations may not be perceivable in this view. In this case it is "+
            "recommended to have a look at the exposure conditions in the numerical "+
            "tables provided below."
        )

        out_mp = plot_exposure_multipanel(
            sim=sim,
            results=sim.model_parameters["x_in"],
            ncols=6,
        )

        lab = self._label.format(placeholder='exposure')
        self._write(f"})")

        return out_mp

    @reporting
    def model_fits(self, sim: "GutsBase"):
        self._write("### Survival model fits")

        self._write(
            "Survival observations on the unit scale with model fits. The solid line is "+
            "the average of individual survival probability predictions from multiple "+
            "draws from the posterior parameter distribution. In case a point estimator "+
            "was used the solid line indicates the best fit. Grey uncertainty intervals "+
            "indicate the uncertainty in survival probabilities. Note that the survival "+
            "probabilities indicate the probability for a given individual or population "+
            "to be alive when observed at time t."
        )

        if sim._exclude_controls_after_fixing_background_mortality: #type: ignore
            lab = self._label.format(placeholder='survival_fits_controls')

            # use the plot that is generated in GutsBase.estimate_background_mortality
            self._write(f"![Survival model fits (control treatments).\label{{{lab}}}]"+
                "(survival_multipanel_control_treatments.png)")

        out_mp = plot_survival_multipanel(
            sim=sim,
            results=sim.inferer.idata.posterior_model_fits,
            ncols=6,
        )

        lab = self._label.format(placeholder='survival_fits')
        self._write(f"})")

        return out_mp

    @reporting
    def LCx_estimates(self, sim: "GutsBase"):
        X = self.config.guts_base.ecx_estimates_x
        T = self.config.guts_base.ecx_estimates_times
        P = sim.predefined_scenarios()

        if self.config.guts_base.ecx_set_background_mortality_to_zero:
            conditions = {p: 0.0 for p in self.config.guts_base.background_mortality_parameters}
        else:
            conditions = {}

        estimates = pd.DataFrame(
            it.product(X, T, P.keys()),
            columns=["x", "time", "scenario"]
        )

        ecx = []

        for i, row in estimates.iterrows():
            msg = "Estimating EC_{x}(t={t} {tu}) for scenario '{s}'".format(
                x=row.x*100, t=row.time, s=row.scenario, tu=self.config.guts_base.unit_time
            )
            print(msg)
            print("=" * len(msg), end="\n")

            ecx_estimator = ECxEstimator(
                sim=sim,
                effect="survival",
                x=row.x,
                time=row.time,
                x_in=P[row.scenario],
                conditions_posterior=conditions
            )

            # find a good starting value
            ecx_estimator.plot_loss_curve(
                mode="mean",
                log_x0=0,
                # this is a huge interval
                log_interval_radius=20,
                log_interval_num=100
            )

            plt.close(ecx_estimator.figure_loss_curve)

            ecx_estimator.estimate(
                mode=sim.config.guts_base.ecx_mode,
                log_x0=float(np.mean(ecx_estimator.ecx_candidates)),
                draws=self.config.guts_base.ecx_draws,
                force_draws=self.config.guts_base.ecx_force_draws,
                show_plot=False
            )

            ecx.append(ecx_estimator.results.copy(deep=True))

        results = pd.DataFrame(ecx)
        estimates[results.columns] = results

        out = self._write_table(tab=estimates, label_insert="$LC_x$ estimates")

        return out


    @reporting
    def table_parameter_estimates(self, posterior, indices):

        if self.rc.table_parameter_estimates_with_batch_dim_vars:
            var_names = {
                k: k for k, v in self.config.model_parameters.free.items()
            }
        else:
            var_names = {
                k: k for k, v in self.config.model_parameters.free.items()
                if self.config.simulation.batch_dimension not in v.dims
            }

        var_names.update(self.rc.table_parameter_estimates_override_names) # type: ignore

        if len(self.rc.table_parameter_estimates_exclude_vars) > 0:
            self._write(f"Excluding parameters: {self.rc.table_parameter_estimates_exclude_vars} for meaningful visualization")

            var_names = {
                k: k for k, v in var_names.items()
                if k not in self.rc.table_parameter_estimates_exclude_vars
            }

        tab_report = create_table(
            posterior=posterior,
            vars=var_names,
            error_metric=self.rc.table_parameter_estimates_error_metric,
            units=self.units,
            stat_focus=self.config.guts_base.table_parameter_stat_focus,
            significant_figures=self.rc.table_parameter_estimates_significant_figures,
            nesting_dimension=indices.keys(),
            parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
        )

        # rewrite table in the desired output format
        tab = create_table(
            posterior=posterior,
            vars=var_names,
            error_metric=self.rc.table_parameter_estimates_error_metric,
            units=self.units,
            stat_focus=self.config.guts_base.table_parameter_stat_focus,
            significant_figures=self.rc.table_parameter_estimates_significant_figures,
            fmt=self.rc.table_parameter_estimates_format,
            nesting_dimension=indices.keys(),
            parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
        )

        self._write_table(tab=tab, tab_report=tab_report, label_insert="Parameter estimates")


def create_table(
    posterior,
    error_metric: Literal["hdi","sd"] = "hdi",
    vars: Dict = {},
    nesting_dimension: Optional[Union[List,str]] = None,
    stat_focus: Literal["mean", "median"] = "mean",
    units: xr.Dataset = xr.Dataset(),
    fmt: Literal["csv", "tsv", "latex"] = "csv",
    significant_figures: int = 3,
    parameters_as_rows: bool = True,
) -> pd.DataFrame:
    """The function is not ready to deal with any nesting dimensionality
    and currently expects the 2-D case
    """
    tab = az.summary(
        posterior, var_names=list(vars.keys()),
        fmt="xarray", kind="stats", stat_focus=stat_focus,
        hdi_prob=0.94
    )

    if TYPE_CHECKING:
        # just for the type checker
        assert isinstance(tab, xr.Dataset)

    tab = tab.rename(vars)

    _units = flatten_coords(
        dataset=create_units(dataset=tab, defined_units=units, stat_focus=stat_focus),
        keep_dims=["metric"]
    )
    tab = flatten_coords(dataset=tab, keep_dims=["metric"])

    tab = tab.apply(np.vectorize(
        partial(round_to_sigfig, sig_fig=significant_figures)
    ))

    if stat_focus == "mean" and error_metric == "sd":
        metrics = ["mean", "sd"]
    elif stat_focus == "mean" and error_metric == "hdi":
        metrics = ["mean", "hdi_3%", "hdi_97%"]
    elif stat_focus == "median" and error_metric == "sd":
        metrics = ["median", "mad"]
    else:
        metrics = ["median", "eti_3%", "eti_97%"]

    if error_metric == "sd":
        arrays = []
        for _, data_var in tab.data_vars.items():
            par_formatted = data_var.sel(metric=metrics)\
                .astype(str).str\
                .join("metric", sep=" ± ")
            arrays.append(par_formatted)

        table = xr.combine_by_coords(arrays)
        table = table.assign_coords(metric=" ± ".join(metrics)).expand_dims("metric")
        table = table.to_dataframe().T

    elif error_metric == "hdi":
        stacked_tab = tab.sel(metric=metrics)\
            .assign_coords(metric=[m.replace("_", " ") for m in metrics])
        table = stacked_tab.to_dataframe().T

    else:
        raise NotImplementedError("Must use one of 'sd' or 'hdi'")

    if fmt == "latex":
        table.columns.names = [str(c).replace('_',' ') for c in table.columns.names]
        table.index = [format_parameter(i) for i in list(table.index)]
        table = table.rename(
            columns={c: c.replace("%", "\\%") for c in table.columns}
        )
    else:
        pass

    table["unit"] = _units.to_pandas().T

    if parameters_as_rows:
        return table
    else:
        return table.T

def flatten_coords(dataset: xr.Dataset, keep_dims):
    """flattens extra coordinates beside the keep_dim dimension for all data variables
    producing an array with harmonized dimensions
    """
    ds = dataset.copy()
    ds = ds.reset_coords(drop=True)
    for var_name, data_var in ds.data_vars.items():
        extra_coords = [k for k in list(data_var.coords.keys()) if k not in keep_dims]
        if len(extra_coords) == 0:
            continue

        data_var_ = data_var.stack(index=extra_coords)

        # otherwise
        for idx in data_var_["index"].values:
            new_var_name = f"{var_name}[{','.join([str(e) for e in idx])}]"
            # reset coordinates to move non-dim index coords from coordinates to the
            # data variables and then select only the var_name from the data vars
            new_data_var = data_var_.sel({"index": idx}).reset_coords()[var_name]
            ds[new_var_name] = new_data_var

        ds = ds.drop(var_name)

    # drop any coordinates that should not be in the dataset at this stage
    extra_coords = [k for k in list(ds.coords.keys()) if k not in keep_dims]
    ds = ds.drop(extra_coords)

    return ds

def create_units(dataset: xr.Dataset, defined_units: xr.Dataset, stat_focus):
    units = dataset.sel(metric=[stat_focus]).astype(str)
    units = units.assign_coords({"metric": ("metric", ["unit"])})
    for k, u in units.data_vars.items():
        if k in defined_units:
            units = units.assign({k: defined_units[k].astype(units[k].dtype)})
        else:
            units[k].values = np.full_like(u.values, "")

    return units

class ParameterConverter:
    def __init__(
        self,
        sim: "GutsBase",
    ):
        self.sim = sim.copy()

        # this converts the units of exposure in the copied simulation
        # and scales the exposure dataarray
        self.sim._convert_exposure_units()
        self.convert_parameters()
        self.sim.prepare_simulation_input()
        self.sim.dispatch_constructor()

        # self.plot_exposure_and_effect(self.sim, sim, _id=7, data_var="D")

        # if parameters are not rescaled this method should raise an error
        self.validate_parameter_conversion_default_params(sim_copy=self.sim, sim_orig=sim)
        self.validate_parameter_conversion_posterior_mean(sim_copy=self.sim, sim_orig=sim)
        self.validate_parameter_conversion_posterior_map(sim_copy=self.sim, sim_orig=sim)

    def convert_parameters(self):
        raise NotImplementedError


    @staticmethod
    def plot_exposure_and_effect(sim_copy, sim_orig, _id=1, data_var="survival"):
        from matplotlib import pyplot as plt
        fig, (ax1, ax2) = plt.subplots(2, 1)
        results_copy = sim_copy.evaluate(parameters=sim_copy.config.model_parameters.value_dict)
        results_orig = sim_orig.evaluate(parameters=sim_orig.config.model_parameters.value_dict)

        ax1.plot(results_orig.time, results_orig["exposure"].isel(id=_id), color="red", label="unscaled")
        ax1.plot(results_copy.time, results_copy["exposure"].isel(id=_id), color="blue", ls="--", label="scaled")
        ax2.plot(results_orig.time, results_orig[data_var].isel(id=_id), color="red", label="unscaled")
        ax2.plot(results_copy.time, results_copy[data_var].isel(id=_id), color="blue", ls="--", label="scaled")
        ax1.legend()
        ax2.legend()
        return fig

    @staticmethod
    def validate_parameter_conversion_default_params(sim_copy, sim_orig):
        results_copy = sim_copy.evaluate(parameters=sim_copy.config.model_parameters.value_dict)
        results_orig = sim_orig.evaluate(parameters=sim_orig.config.model_parameters.value_dict)

        np.testing.assert_allclose(results_copy.H, results_orig.H, atol=0.1, rtol=0.05)

    @staticmethod
    def validate_parameter_conversion_posterior_mean(sim_copy, sim_orig):
        results_copy = sim_copy.evaluate(parameters=sim_copy.point_estimate("mean", to="dict"))
        results_orig = sim_orig.evaluate(parameters=sim_orig.point_estimate("mean", to="dict"))

        np.testing.assert_allclose(results_copy.H, results_orig.H, atol=0.1, rtol=0.05)

    @staticmethod
    def validate_parameter_conversion_posterior_map(sim_copy, sim_orig):
        results_copy = sim_copy.evaluate(parameters=sim_copy.point_estimate("map", to="dict"))
        results_orig = sim_orig.evaluate(parameters=sim_orig.point_estimate("map", to="dict"))

        np.testing.assert_allclose(results_copy.H, results_orig.H, atol=0.1, rtol=0.05)