guts-base 0.8.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +2 -1
- guts_base/data/__init__.py +1 -1
- guts_base/data/generator.py +2 -1
- guts_base/data/survival.py +6 -0
- guts_base/mod.py +24 -83
- guts_base/prob.py +23 -275
- guts_base/sim/__init__.py +10 -1
- guts_base/sim/base.py +285 -75
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +168 -58
- guts_base/sim/mempy.py +85 -70
- guts_base/sim/report.py +0 -1
- guts_base/sim/utils.py +10 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/METADATA +2 -3
- guts_base-1.0.0.dist-info/RECORD +25 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.6.dist-info/RECORD +0 -24
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/WHEEL +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/entry_points.txt +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/top_level.txt +0 -0
guts_base/sim/base.py
CHANGED
|
@@ -2,28 +2,30 @@ import os
|
|
|
2
2
|
import glob
|
|
3
3
|
from functools import partial
|
|
4
4
|
from copy import deepcopy
|
|
5
|
+
import importlib
|
|
5
6
|
import warnings
|
|
6
7
|
import numpy as np
|
|
7
8
|
import xarray as xr
|
|
8
9
|
from diffrax import Dopri5
|
|
9
|
-
from typing import Literal, Optional, List, Dict
|
|
10
|
+
from typing import Literal, Optional, List, Dict, Mapping, Sequence, Tuple
|
|
10
11
|
import tempfile
|
|
11
12
|
import pandas as pd
|
|
12
13
|
|
|
13
14
|
from pymob import SimulationBase
|
|
14
|
-
from pymob.sim.config import DataVariable, Param, string_to_list
|
|
15
|
+
from pymob.sim.config import DataVariable, Param, string_to_list, NumericArray
|
|
15
16
|
|
|
16
17
|
from pymob.solvers import JaxSolver
|
|
17
18
|
from pymob.solvers.base import rect_interpolation
|
|
18
|
-
from pymob.sim.config import ParameterDict
|
|
19
19
|
from expyDB.intervention_model import (
|
|
20
20
|
Treatment, Timeseries, select, from_expydb
|
|
21
21
|
)
|
|
22
22
|
|
|
23
|
+
|
|
24
|
+
from guts_base.sim.utils import GutsBaseError
|
|
23
25
|
from guts_base import mod
|
|
24
26
|
from guts_base.data import (
|
|
25
27
|
to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
|
|
26
|
-
create_database_and_import_data_main, design_exposure_scenario
|
|
28
|
+
create_database_and_import_data_main, design_exposure_scenario, ExposureDataDict
|
|
27
29
|
)
|
|
28
30
|
from guts_base.sim.report import GutsReport
|
|
29
31
|
|
|
@@ -35,48 +37,77 @@ class GutsBase(SimulationBase):
|
|
|
35
37
|
1. check if necessary entries are made in the configuration, otherwise add defaults
|
|
36
38
|
2. read data or take from input
|
|
37
39
|
3. process data (add dimensions, or add indices)
|
|
40
|
+
4. Prepare model input
|
|
38
41
|
"""
|
|
39
42
|
solver = JaxSolver
|
|
40
43
|
Report = GutsReport
|
|
44
|
+
results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
|
|
45
|
+
_skip_data_processing: bool = False
|
|
46
|
+
|
|
47
|
+
def initialize(self, input: Optional[Dict] = None):
    """Set up the simulation in three consecutive stages.

    1. Configuration: ``configure_case_study`` makes case-study specific
       changes to the configuration or sets state variables that are
       relevant for the simulation.
       TODO: Ideally everything that is configurable ends up in the config
       so it can be serialized.

    2. Data import: the observations are taken from ``read_data`` and,
       unless ``_skip_data_processing`` is set, handed to ``process_data``.
       Both methods can be adapted or overwritten in subclasses.

    3. Simulation input: ``prepare_simulation_input`` prepares ``y0`` and
       ``x_in``.

    Splitting the initialization into these stages allows subclasses to
    modify a single stage instead of re-implementing the whole method.

    Parameters
    ----------
    input : Optional[Dict]
        Accepted for interface compatibility; none of the statements in
        this method read it.
    """
    # Stage 1: configuration
    self.configure_case_study()

    # Stage 2: data import
    self.observations = self.read_data()
    # FIXME: Saving observations here is not intuitive. When a simulation is
    # exported, the last used state should be stored, not an intermediate one:
    # self.save_observations(filename="observations.nc", directory=self.output_path, force=True)
    run_processing = not self._skip_data_processing
    if run_processing:
        self.process_data()

    # Stage 3: prepare y0 and x_in
    self.prepare_simulation_input()
|
|
83
|
+
|
|
84
|
+
def configure_case_study(self):
|
|
85
|
+
"""Modify configuration file or set state variables
|
|
86
|
+
TODO: This should only modify the configuration file, so that changes
|
|
87
|
+
are transparent.
|
|
88
|
+
"""
|
|
89
|
+
if self._model_class is not None:
|
|
90
|
+
self.model = self._model_class._rhs_jax
|
|
91
|
+
self.solver_post_processing = self._model_class._solver_post_processing
|
|
41
92
|
|
|
42
|
-
def initialize(self, input: Dict = None):
|
|
43
93
|
self.ecx_mode: Literal["mean", "draws"] = "mean"
|
|
44
94
|
|
|
45
95
|
self.unit_time: Literal["day", "hour", "minute", "second"] = "day"
|
|
46
96
|
if hasattr(self.config.simulation, "unit_time"):
|
|
47
97
|
self.unit_time = self.config.simulation.unit_time # type: ignore
|
|
48
98
|
|
|
49
|
-
self.
|
|
50
|
-
|
|
51
|
-
self.results_interpolation = string_to_list(self.config.simulation.results_interpolation)
|
|
52
|
-
self.results_interpolation[0] = float(self.results_interpolation[0])
|
|
53
|
-
self.results_interpolation[1] = float(self.results_interpolation[1])
|
|
54
|
-
self.results_interpolation[2] = int(self.results_interpolation[2])
|
|
55
|
-
|
|
56
|
-
if "observations" in input:
|
|
57
|
-
self.observations = input["observations"]
|
|
58
|
-
else:
|
|
59
|
-
self.observations = self.read_data()
|
|
60
|
-
self.process_data()
|
|
61
|
-
|
|
62
|
-
# define tolerance based on the sovler tolerance
|
|
63
|
-
self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
|
|
64
|
-
|
|
65
|
-
self._reindex_time_dim()
|
|
99
|
+
if hasattr(self.config.simulation, "skip_data_processing"):
|
|
100
|
+
self._skip_data_processing = bool(self.config.simulation.skip_data_processing) # type: ignore
|
|
66
101
|
|
|
67
|
-
if "
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
observations=self.observations
|
|
102
|
+
if hasattr(self.config.simulation, "results_interpolation"):
|
|
103
|
+
results_interpolation_string = string_to_list(self.config.simulation.results_interpolation)
|
|
104
|
+
self.results_interpolation = (
|
|
105
|
+
float(results_interpolation_string[0]),
|
|
106
|
+
float(results_interpolation_string[1]),
|
|
107
|
+
int(results_interpolation_string[2])
|
|
74
108
|
)
|
|
75
109
|
|
|
76
|
-
|
|
77
|
-
self.config.data_structure.exposure.observed=False
|
|
78
|
-
|
|
79
|
-
# prepare y0 and x_in
|
|
110
|
+
def prepare_simulation_input(self):
|
|
80
111
|
x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
|
|
81
112
|
y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])
|
|
82
113
|
|
|
@@ -189,6 +220,23 @@ class GutsBase(SimulationBase):
|
|
|
189
220
|
self._create_indices()
|
|
190
221
|
self._indices_to_dimensions()
|
|
191
222
|
|
|
223
|
+
# define tolerance based on the solver tolerance
|
|
224
|
+
self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
|
|
225
|
+
|
|
226
|
+
self._reindex_time_dim()
|
|
227
|
+
|
|
228
|
+
if "survival" in self.observations:
|
|
229
|
+
if "subject_count" not in self.observations.coords:
|
|
230
|
+
self.observations = self.observations.assign_coords(
|
|
231
|
+
subject_count=("id", self.observations["survival"].isel(time=0).values, )
|
|
232
|
+
)
|
|
233
|
+
self.observations = self._data.prepare_survival_data_for_conditional_binomial(
|
|
234
|
+
observations=self.observations
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
if "exposure" in self.observations:
|
|
238
|
+
self.config.data_structure.exposure.observed=False
|
|
239
|
+
|
|
192
240
|
def _create_indices(self):
|
|
193
241
|
"""Use if indices should be added to sim.indices and sim.observations"""
|
|
194
242
|
pass
|
|
@@ -225,8 +273,6 @@ class GutsBase(SimulationBase):
|
|
|
225
273
|
"is calculated without a dense time resolution, the estimates can be biased!"
|
|
226
274
|
))
|
|
227
275
|
|
|
228
|
-
|
|
229
|
-
|
|
230
276
|
def recompute_posterior(self):
|
|
231
277
|
"""This function interpolates the posterior with a given resolution
|
|
232
278
|
posterior_predictions calculate proper survival predictions for the
|
|
@@ -236,17 +282,13 @@ class GutsBase(SimulationBase):
|
|
|
236
282
|
than the original dataset
|
|
237
283
|
"""
|
|
238
284
|
|
|
239
|
-
|
|
240
|
-
self.results_interpolation[0] = float(self.observations["time"].min())
|
|
241
|
-
|
|
242
|
-
if np.isnan(self.results_interpolation[1]):
|
|
243
|
-
self.results_interpolation[1] = float(self.observations["time"].max())
|
|
285
|
+
ri = self.results_interpolation
|
|
244
286
|
|
|
245
287
|
# generate high resolution posterior predictions
|
|
246
288
|
if self.results_interpolation is not None:
|
|
247
289
|
time_interpolate = np.linspace(
|
|
248
|
-
start=self.
|
|
249
|
-
stop=self.
|
|
290
|
+
start=float(self.observations["time"].min()) if np.isnan(ri[0]) else ri[0],
|
|
291
|
+
stop=float(self.observations["time"].max()) if np.isnan(ri[1]) else ri[1],
|
|
250
292
|
num=self.results_interpolation[2]
|
|
251
293
|
)
|
|
252
294
|
|
|
@@ -254,11 +296,15 @@ class GutsBase(SimulationBase):
|
|
|
254
296
|
# a) helps error checking during posterior predictions.
|
|
255
297
|
# b) makes sure that the original time vector is retained, which may be
|
|
256
298
|
# relevant for the simulation success (e.g. IT model)
|
|
257
|
-
|
|
299
|
+
obs = self.observations.reindex(
|
|
258
300
|
time=np.unique(np.concatenate(
|
|
259
301
|
[time_interpolate, self.observations["time"]]
|
|
260
|
-
))
|
|
302
|
+
)),
|
|
261
303
|
)
|
|
304
|
+
|
|
305
|
+
obs["survivors_before_t"] = obs.survivors_before_t.ffill(dim="time").astype(int)
|
|
306
|
+
obs["survivors_at_start"] = obs.survivors_at_start.ffill(dim="time").astype(int)
|
|
307
|
+
self.observations = obs
|
|
262
308
|
|
|
263
309
|
self.dispatch_constructor()
|
|
264
310
|
_ = self._prob.posterior_predictions(self, self.inferer.idata) # type: ignore
|
|
@@ -266,13 +312,13 @@ class GutsBase(SimulationBase):
|
|
|
266
312
|
self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
|
|
267
313
|
|
|
268
314
|
|
|
269
|
-
def prior_predictive_checks(self):
|
|
270
|
-
super().prior_predictive_checks()
|
|
315
|
+
def prior_predictive_checks(self, **plot_kwargs):
|
|
316
|
+
super().prior_predictive_checks(**plot_kwargs)
|
|
271
317
|
|
|
272
318
|
self._plot.plot_prior_predictions(self, data_vars=["survival"])
|
|
273
319
|
|
|
274
|
-
def posterior_predictive_checks(self):
|
|
275
|
-
super().posterior_predictive_checks()
|
|
320
|
+
def posterior_predictive_checks(self, **plot_kwargs):
|
|
321
|
+
super().posterior_predictive_checks(**plot_kwargs)
|
|
276
322
|
|
|
277
323
|
self.recompute_posterior()
|
|
278
324
|
# TODO: Include posterior_predictive group once the survival predictions are correctly working
|
|
@@ -282,28 +328,6 @@ class GutsBase(SimulationBase):
|
|
|
282
328
|
def plot(self, results):
|
|
283
329
|
self._plot.plot_survival(self, results)
|
|
284
330
|
|
|
285
|
-
def copy(self):
|
|
286
|
-
with warnings.catch_warnings(action="ignore"):
|
|
287
|
-
sim_copy = type(self)(self.config.copy(deep=True))
|
|
288
|
-
sim_copy.observations = self.observations.copy(deep=True)
|
|
289
|
-
|
|
290
|
-
# must copy individual parts of the dict due to the on_update method
|
|
291
|
-
model_parameters = {k: deepcopy(v) for k, v in self.model_parameters.items()}
|
|
292
|
-
|
|
293
|
-
# TODO: Refactor this once the parameterize method is removed.
|
|
294
|
-
sim_copy.parameterize = partial(sim_copy.parameterize, model_parameters=model_parameters)
|
|
295
|
-
sim_copy._model_parameters = ParameterDict(model_parameters, callback=sim_copy._on_params_updated)
|
|
296
|
-
|
|
297
|
-
sim_copy.load_modules()
|
|
298
|
-
if hasattr(self, "inferer"):
|
|
299
|
-
sim_copy.inferer = type(self.inferer)(sim_copy)
|
|
300
|
-
# idata uses deepcopy by default
|
|
301
|
-
sim_copy.inferer.idata = self.inferer.idata.copy()
|
|
302
|
-
sim_copy.model = self.model
|
|
303
|
-
sim_copy.solver_post_processing = self.solver_post_processing
|
|
304
|
-
|
|
305
|
-
return sim_copy
|
|
306
|
-
|
|
307
331
|
def predefined_scenarios(self):
|
|
308
332
|
"""
|
|
309
333
|
TODO: Fix timescale to observations
|
|
@@ -343,10 +367,10 @@ class GutsBase(SimulationBase):
|
|
|
343
367
|
|
|
344
368
|
scenarios = {}
|
|
345
369
|
for coord in exposure_coordinates:
|
|
346
|
-
concentrations = np.where(coord == exposure_coordinates, 1, 0)
|
|
370
|
+
concentrations = np.where(coord == exposure_coordinates, 1.0, 0.0)
|
|
347
371
|
|
|
348
372
|
exposure_dict = {
|
|
349
|
-
coord:
|
|
373
|
+
coord: ExposureDataDict(start=0, end=1, concentration=conc)
|
|
350
374
|
for coord, conc in zip(exposure_coordinates, concentrations)
|
|
351
375
|
}
|
|
352
376
|
|
|
@@ -363,6 +387,34 @@ class GutsBase(SimulationBase):
|
|
|
363
387
|
|
|
364
388
|
return scenarios
|
|
365
389
|
|
|
390
|
+
@staticmethod
def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
    """Combine per-key exposure tables into a dataset with one ``exposure`` variable.

    Each table is converted with ``DataFrame.to_xarray`` and its columns are
    stacked along a new ``id`` dimension. The resulting arrays (one per key
    of ``exposure_data``) are then stacked along ``dim``.

    TODO: Currently no rect interpolation

    Parameters
    ----------
    exposure_data : Dict[str, pd.DataFrame]
        Exposure tables; the index of every table is treated as time.
        The tables are not modified.
    dim : str
        Name of the dimension along which the keys of ``exposure_data``
        are stacked.

    Returns
    -------
    xr.Dataset
        Dataset holding the variable ``exposure`` with ``id`` and ``time``
        as leading dimensions.
    """
    arrays = {
        # The index must be called "time" so that all dimensions line up.
        # rename_axis returns a new frame, so the caller's table keeps its
        # original index name.
        key: df.rename_axis("time").to_xarray().to_dataarray(dim="id", name=key)
        for key, df in exposure_data.items()
    }

    exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
    exposure_array = exposure_array.transpose("id", "time", ...)
    return xr.Dataset({"exposure": exposure_array})
|
|
406
|
+
|
|
407
|
+
@staticmethod
def _survival_data_to_xarray(survival_data: pd.DataFrame):
    """Convert a survival table into a dataset with one ``survival`` variable.

    The table is converted with ``DataFrame.to_xarray`` and its columns are
    stacked along a new ``id`` dimension.

    Parameters
    ----------
    survival_data : pd.DataFrame
        Survival table; its index is treated as time. The table is not
        modified.

    Returns
    -------
    xr.Dataset
        Dataset holding the variable ``survival`` with ``id`` and ``time``
        as leading dimensions.
    """
    # TODO: survival name is currently not kept because the raw data is not
    # transferred from the survival
    # rename_axis returns a new frame, so the caller's table keeps its
    # original index name.
    timed = survival_data.rename_axis("time")

    survival_array = timed.to_xarray().to_dataarray(dim="id", name="survival")
    survival_array = survival_array.transpose("id", "time", ...)
    return xr.Dataset({"survival": survival_array})
|
|
416
|
+
|
|
417
|
+
|
|
366
418
|
def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
|
|
367
419
|
"""This method will take an existing coordinate of a dataset that has the same
|
|
368
420
|
coordinate has the batch dimension. It will then re-express the coordinate as a
|
|
@@ -429,6 +481,164 @@ class GutsBase(SimulationBase):
|
|
|
429
481
|
def initialize_from_script(self):
|
|
430
482
|
pass
|
|
431
483
|
|
|
484
|
+
@property
def _model_class(self):
    """Resolve the object named by ``config.simulation.model_class``.

    The option is read as a dotted import path: everything before the last
    dot is imported as a module and the part after it is looked up as an
    attribute of that module.

    Returns
    -------
    The resolved attribute, or ``None`` when the option is absent or empty.

    Raises
    ------
    ValueError
        If the option contains no dot and therefore cannot be split into a
        module and an attribute name.
    """
    dotted_path = getattr(self.config.simulation, "model_class", None)
    if not dotted_path:
        # option not configured
        return None

    module_name, separator, attr_name = dotted_path.rpartition(".")
    if not separator:
        raise ValueError(
            f"config.simulation.model_class='{dotted_path}' must be a dotted "
            "import path of the form 'package.module.ClassName'"
        )

    return getattr(importlib.import_module(module_name), attr_name)
|
|
492
|
+
|
|
493
|
+
### API methods ###
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def point_estimate(
    self,
    estimate: Literal["mean", "map"] = "map",
    to: Literal["xarray", "dict"] = "xarray"
):
    """Returns a point estimate of the posterior. If you want more control over the posterior
    use the attribute: sim.inferer.idata.posterior and summarize it or select from it
    using the arviz (https://python.arviz.org/en/stable/index.html) and the
    xarray (https://docs.xarray.dev/en/stable/index.html) packages

    Parameters
    ----------

    estimate : Literal["map", "mean"]
        Point estimate to return.
        - map: The posterior sample with the highest log-likelihood, summed
          over ``id``, ``time`` and all variables of ``idata.log_likelihood``.
          Because a single sample is returned, the correlation structure of
          the posterior is respected.
          NOTE(review): only the log-likelihood enters the selection here;
          confirm whether the prior density should be included for a MAP.
        - mean: The average of all marginal parameter distributions.

    to : Literal["xarray", "dict"]
        Specifies the representation to transform the summarized data to. dict can
        be used to insert parameters in the .evaluate() method. While xarray is the
        standard view. Defaults to xarray

    Raises
    ------
    GutsBaseError
        If ``estimate`` or ``to`` is not one of the supported options.

    Example
    -------

    >>> sim.point_estimate(to='dict')
    """
    if estimate == "mean":
        best_estimate = self.inferer.idata.posterior.mean(("chain", "draw"))

    elif estimate == "map":
        loglik = self.inferer.idata.log_likelihood\
            .sum(["id", "time"])\
            .to_array().sum("variable")

        # argmax over several dimensions returns a dict of *positional*
        # indices, which must be consumed by isel. Label based selection
        # would pick the wrong sample (or fail) whenever the chain/draw
        # coordinates are not 0..n-1, e.g. after slicing or thinning.
        sample_max_loglik = loglik.argmax(dim=("chain", "draw"))
        best_estimate = self.inferer.idata.posterior.isel(sample_max_loglik)
    else:
        raise GutsBaseError(
            f"Estimate '{estimate}' not implemented. Choose one of ['mean', 'map']"
        )

    if to == "xarray":
        return best_estimate

    elif to == "dict":
        return {k: v.values for k, v in best_estimate.items()}

    else:
        raise GutsBaseError(
            "GutsBase.point_estimate() supports only return types to=['xarray', 'dict']. " +
            f"You used {to=}"
        )
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def evaluate(
    self,
    parameters: Optional[Mapping[str, "float | NumericArray | Sequence[float]"]] = None,
    y0: Optional[Mapping[str, "float | NumericArray | Sequence[float]"]] = None,
    x_in: Optional[Mapping[str, "float | NumericArray | Sequence[float]"]] = None,
):
    """Evaluates the model along the coordinates of the observations with given
    parameters, x_in, and y0. The dictionaries passed to the function arguments
    only overwrite the existing default parameters; which makes the usage very simple.

    Note that the first run of .evaluate() after calling the .dispatch_constructor()
    takes a little longer, because the model and solver are jit-compiled to JAX for
    highly efficient computations.

    Parameters
    ----------

    parameters : Mapping[str, float|NumericArray|Sequence[float]], optional
        Model parameters that should be changed for dispatch (forwarded as
        ``theta``). Unspecified model parameters will assume the default
        values, specified under config.model_parameters.NAME.value

    y0 : Mapping[str, float|NumericArray|Sequence[float]], optional
        Initial values that should be changed for dispatch.

    x_in : Mapping[str, float|NumericArray|Sequence[float]], optional
        Model input values that should be changed for dispatch.

    Omitting an argument (or passing ``None``) is equivalent to passing an
    empty mapping.

    Returns
    -------
    The ``results`` attribute of the evaluator after it has been run.

    Example
    -------

    >>> sim.dispatch_constructor()  # necessary if the sim object has been modified
    >>> # evaluate setting the background mortality to zero
    >>> sim.evaluate(parameters={'hb': 0.0})

    """
    # A fresh dict per call avoids sharing one mutable default object between
    # all invocations.
    evaluator = self.dispatch(
        theta={} if parameters is None else parameters,
        x_in={} if x_in is None else x_in,
        y0={} if y0 is None else y0,
    )
    evaluator()
    return evaluator.results
|
|
594
|
+
|
|
595
|
+
def load_exposure_scenario(
    self,
    data: str|Dict[str,pd.DataFrame],
    sheet_name_prefix: str = "",
    rect_interpolate=False

):
    """Replace the exposure of the simulation with a new exposure scenario.

    The previous observations are kept in ``self._obs_backup``. The new
    exposure is combined with the existing ``survival`` observations,
    restricted to the ids of the exposure and to time points not later than
    the last exposure time point. Afterwards ``x_in`` and ``y0`` are parsed
    again and ``dispatch_constructor`` is called.

    Parameters
    ----------
    data : str | Dict[str, pd.DataFrame]
        Either a path that is handed to ``read_excel_file`` or the exposure
        tables themselves.
    sheet_name_prefix : str
        Forwarded to ``read_excel_file``; only used when ``data`` is a path.
    rect_interpolate : bool
        Currently not used by this method.

    Raises
    ------
    GutsBaseError
        If the configured exposure dimensions contain nothing besides the
        x-dimension and the batch dimension.
    """
    if isinstance(data, str):
        # NOTE(review): `read_excel_file` is not among the imports visible in
        # this file section - confirm that it is in scope.
        # The second return value (time unit) is not needed here.
        _data, _ = read_excel_file(
            path=data,
            sheet_name_prefix=sheet_name_prefix,
            convert_time_to=self.unit_time
        )
    else:
        _data = data

    self._obs_backup = self.observations.copy(deep=True)

    # the exposure tables are stacked along the first exposure dimension that
    # is neither the x-dimension nor the batch dimension
    exposure_dim = [
        d for d in self.config.data_structure.exposure.dimensions
        if d not in (self.config.simulation.x_dimension, self.config.simulation.batch_dimension)
    ]
    if not exposure_dim:
        raise GutsBaseError(
            "config.data_structure.exposure.dimensions must contain a dimension "
            "besides the x-dimension and the batch dimension to load an exposure scenario."
        )

    # read exposure array from file
    exposure = self._exposure_data_to_xarray(
        exposure_data=_data,
        dim=exposure_dim[0]
    )

    # combine with observations
    new_obs = xr.combine_by_coords([
        exposure,
        self.observations.survival
    ]).sel(id=exposure.id)

    # the last exposure time point is the same for every comparison, so it is
    # computed once
    t_end = exposure.time.max()
    self.observations = new_obs.sel(time=[t for t in new_obs.time if t <= t_end])
    self.config.simulation.x_in = ["exposure=exposure"]
    self.model_parameters["x_in"] = self.parse_input("x_in", exposure).ffill("time") # type: ignore
    self.model_parameters["y0"] = self.parse_input("y0", drop_dims=["time"])

    self.dispatch_constructor()
|
|
637
|
+
|
|
638
|
+
def export(self, directory: Optional[str] = None):
    """Export the simulation via the parent class.

    Before delegating, ``skip_data_processing`` is set to ``False`` on the
    simulation section of the configuration.
    NOTE(review): presumably so that an exported simulation processes its
    data again when it is re-initialized - confirm.

    Parameters
    ----------
    directory : Optional[str]
        Forwarded unchanged to the parent ``export``.
    """
    simulation_config = self.config.simulation
    simulation_config.skip_data_processing = False
    super().export(directory=directory)
|
|
641
|
+
|
|
432
642
|
class GutsSimulationConstantExposure(GutsBase):
|
|
433
643
|
t_max = 10
|
|
434
644
|
def initialize_from_script(self):
|
|
@@ -449,7 +659,7 @@ class GutsSimulationConstantExposure(GutsBase):
|
|
|
449
659
|
self.config.model_parameters.z = Param(value=0.2, free=True)
|
|
450
660
|
|
|
451
661
|
self.model_parameters["parameters"] = self.config.model_parameters.value_dict
|
|
452
|
-
self.config.simulation.model = "
|
|
662
|
+
self.config.simulation.model = "guts_constant_exposure"
|
|
453
663
|
|
|
454
664
|
self.coordinates["time"] = np.linspace(0,self.t_max)
|
|
455
665
|
|
|
@@ -459,7 +669,7 @@ class GutsSimulationConstantExposure(GutsBase):
|
|
|
459
669
|
# =======================
|
|
460
670
|
|
|
461
671
|
self.coordinates["time"] = np.array([0,self.t_max])
|
|
462
|
-
self.config.simulation.model = "
|
|
672
|
+
self.config.simulation.model = "guts_constant_exposure"
|
|
463
673
|
|
|
464
674
|
self.solver = JaxSolver
|
|
465
675
|
|
|
@@ -500,7 +710,7 @@ class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
|
|
|
500
710
|
self.config.model_parameters.remove("C_0")
|
|
501
711
|
|
|
502
712
|
self.model_parameters["parameters"] = self.config.model_parameters.value_dict
|
|
503
|
-
self.config.simulation.solver_post_processing = "
|
|
713
|
+
self.config.simulation.solver_post_processing = "red_sd_post_processing"
|
|
504
714
|
self.config.simulation.model = "guts_variable_exposure"
|
|
505
715
|
|
|
506
716
|
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import arviz as az
|
|
3
|
+
from guts_base.sim import GutsBase
|
|
4
|
+
|
|
5
|
+
def construct_sim_from_config(
    scenario: str,
    simulation_class: type,
    output_path=None
) -> "GutsBase":
    """Helper function to construct simulations for debugging.

    Parameters
    ----------
    scenario : str
        Scenario name; the configuration is read from
        ``scenarios/<scenario>/settings.cfg``.
    simulation_class : type
        Class that is instantiated with the path of the settings file.
    output_path : str or path-like, optional
        Base directory for the results. When given, the output path of the
        case study becomes
        ``<output_path>/<case study name>/results/<scenario>``. When
        omitted, the scenario of the case study is renamed to
        ``debug_test`` instead.

    Returns
    -------
    The simulation after ``setup()`` has been called.
    """
    # imported locally so that the module-level imports stay untouched
    from pathlib import Path

    sim = simulation_class(f"scenarios/{scenario}/settings.cfg")

    if output_path is not None:
        # this sets a different output directory; wrapping in Path accepts
        # plain strings as well as path objects
        case_study = sim.config.case_study
        results_dir = Path(output_path) / case_study.name / "results" / case_study.scenario
        case_study.output_path = str(results_dir)
    else:
        sim.config.case_study.scenario = "debug_test"

    sim.setup()
    return sim
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_idata(sim: GutsBase, idata_file: str) -> GutsBase:
    """Attach inference data from a netCDF file to a simulation.

    The inferer of the simulation is set to ``"numpyro"``. If ``idata_file``
    exists it is read with ``arviz.from_netcdf`` and stored as
    ``sim.inferer.idata``; otherwise ``sim.inferer.idata`` is set to ``None``.

    Parameters
    ----------
    sim : GutsBase
        Simulation that receives the inference data.
    idata_file : str
        Path of the netCDF file.

    Returns
    -------
    GutsBase
        The same simulation object that was passed in.
    """
    sim.set_inferer("numpyro")
    sim.inferer.idata = (
        az.from_netcdf(idata_file) if os.path.exists(idata_file) else None
    )
    return sim
|