guts-base 0.8.5__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +2 -1
- guts_base/data/__init__.py +1 -1
- guts_base/data/generator.py +6 -5
- guts_base/data/survival.py +6 -0
- guts_base/mod.py +27 -80
- guts_base/prob.py +23 -275
- guts_base/sim/__init__.py +10 -1
- guts_base/sim/base.py +350 -78
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +221 -63
- guts_base/sim/mempy.py +85 -70
- guts_base/sim/report.py +9 -9
- guts_base/sim/utils.py +10 -0
- {guts_base-0.8.5.dist-info → guts_base-1.0.0.dist-info}/METADATA +3 -4
- guts_base-1.0.0.dist-info/RECORD +25 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.5.dist-info/RECORD +0 -24
- {guts_base-0.8.5.dist-info → guts_base-1.0.0.dist-info}/WHEEL +0 -0
- {guts_base-0.8.5.dist-info → guts_base-1.0.0.dist-info}/entry_points.txt +0 -0
- {guts_base-0.8.5.dist-info → guts_base-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {guts_base-0.8.5.dist-info → guts_base-1.0.0.dist-info}/top_level.txt +0 -0
guts_base/sim/base.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import glob
|
|
3
|
+
from functools import partial
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
import importlib
|
|
3
6
|
import warnings
|
|
4
7
|
import numpy as np
|
|
5
8
|
import xarray as xr
|
|
6
9
|
from diffrax import Dopri5
|
|
7
|
-
from typing import Literal, Optional, List, Dict
|
|
10
|
+
from typing import Literal, Optional, List, Dict, Mapping, Sequence, Tuple
|
|
8
11
|
import tempfile
|
|
9
12
|
import pandas as pd
|
|
10
13
|
|
|
11
14
|
from pymob import SimulationBase
|
|
12
|
-
from pymob.sim.config import DataVariable, Param, string_to_list
|
|
15
|
+
from pymob.sim.config import DataVariable, Param, string_to_list, NumericArray
|
|
13
16
|
|
|
14
17
|
from pymob.solvers import JaxSolver
|
|
15
18
|
from pymob.solvers.base import rect_interpolation
|
|
@@ -17,14 +20,15 @@ from expyDB.intervention_model import (
|
|
|
17
20
|
Treatment, Timeseries, select, from_expydb
|
|
18
21
|
)
|
|
19
22
|
|
|
23
|
+
|
|
24
|
+
from guts_base.sim.utils import GutsBaseError
|
|
20
25
|
from guts_base import mod
|
|
21
26
|
from guts_base.data import (
|
|
22
27
|
to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
|
|
23
|
-
create_database_and_import_data_main, design_exposure_scenario
|
|
28
|
+
create_database_and_import_data_main, design_exposure_scenario, ExposureDataDict
|
|
24
29
|
)
|
|
25
30
|
from guts_base.sim.report import GutsReport
|
|
26
31
|
|
|
27
|
-
|
|
28
32
|
class GutsBase(SimulationBase):
|
|
29
33
|
"""
|
|
30
34
|
Initializes GUTS models from a variety of data sources
|
|
@@ -33,48 +37,77 @@ class GutsBase(SimulationBase):
|
|
|
33
37
|
1. check if necessary entries are made in the configuration, otherwise add defaults
|
|
34
38
|
2. read data or take from input
|
|
35
39
|
3. process data (add dimensions, or add indices)
|
|
40
|
+
4. Prepare model input
|
|
36
41
|
"""
|
|
37
42
|
solver = JaxSolver
|
|
38
43
|
Report = GutsReport
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
|
|
45
|
+
_skip_data_processing: bool = False
|
|
46
|
+
|
|
47
|
+
def initialize(self, input: Optional[Dict] = None):
    """Initialize the simulation in three consecutive steps.

    1. Configuration (``configure_case_study``): makes case-study specific
       changes to the configuration or sets state variables that are relevant
       for the simulation.
       TODO: Ideally everything that is configurable ends up in the config so
       it can be serialized.

    2. Import data: ``read_data`` provides the observations, which are then
       processed with ``process_data`` unless ``self._skip_data_processing``
       is set. ``process_data`` itself uses the submethods ``_create_indices``
       and ``_indices_to_dimensions``, which are empty by default but can be
       overridden in subclasses if needed. ``save_observations`` is currently
       not called here (see the FIXME below).

    3. Prepare the simulation input (``prepare_simulation_input``), e.g. y0
       and x_in.

    Splitting the initialization into these steps allows subclasses to modify
    it with higher granularity by overriding individual submethods.

    Parameters
    ----------
    input : Optional[Dict]
        Not used by this implementation; kept for signature compatibility.
    """
    # 1. Configuration (may also set self._skip_data_processing)
    self.configure_case_study()

    # 2. Import data
    self.observations = self.read_data()
    # FIXME: Saving observations here is not intuitive. If I export a simulation,
    # I want to use the last used state, not some obscure intermediate state
    # self.save_observations(filename="observations.nc", directory=self.output_path, force=True)
    if not self._skip_data_processing:
        self.process_data()

    # 3. prepare y0 and x_in
    self.prepare_simulation_input()
|
|
53
83
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
84
|
+
def configure_case_study(self):
    """Modify configuration file or set state variables.

    Sets the following attributes on the simulation:

    - ``model`` and ``solver_post_processing`` from ``_rhs_jax`` and
      ``_solver_post_processing`` of the class referenced by
      ``config.simulation.model_class`` (only if such a class is configured)
    - ``ecx_mode`` to ``"mean"``
    - ``unit_time`` (default ``"day"``) from ``config.simulation.unit_time``
    - ``_skip_data_processing`` from ``config.simulation.skip_data_processing``
    - ``results_interpolation`` as ``(start, stop, num)`` from
      ``config.simulation.results_interpolation``

    TODO: This should only modify the configuration file, so that changes
    are transparent.

    Raises
    ------
    GutsBaseError
        If ``config.simulation.results_interpolation`` has fewer than three
        entries.
    """
    # resolve once: the _model_class property imports the module on every access
    model_class = self._model_class
    if model_class is not None:
        self.model = model_class._rhs_jax
        self.solver_post_processing = model_class._solver_post_processing

    self.ecx_mode: Literal["mean", "draws"] = "mean"

    self.unit_time: Literal["day", "hour", "minute", "second"] = "day"
    if hasattr(self.config.simulation, "unit_time"):
        self.unit_time = self.config.simulation.unit_time  # type: ignore

    if hasattr(self.config.simulation, "skip_data_processing"):
        skip = self.config.simulation.skip_data_processing  # type: ignore
        if isinstance(skip, str):
            # NOTE(review): settings read from a file may arrive as strings;
            # bool("False") would be True, so parse common false spellings.
            skip = skip.strip().lower() not in ("", "false", "0", "no", "off")
        self._skip_data_processing = bool(skip)

    if hasattr(self.config.simulation, "results_interpolation"):
        entries = string_to_list(self.config.simulation.results_interpolation)
        if len(entries) < 3:
            raise GutsBaseError(
                "config.simulation.results_interpolation must provide three "
                f"values (start, stop, num); got {entries!r}"
            )
        self.results_interpolation = (
            float(entries[0]),
            float(entries[1]),
            int(entries[2]),
        )
|
|
76
109
|
|
|
77
|
-
|
|
110
|
+
def prepare_simulation_input(self):
|
|
78
111
|
x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
|
|
79
112
|
y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])
|
|
80
113
|
|
|
@@ -187,6 +220,23 @@ class GutsBase(SimulationBase):
|
|
|
187
220
|
self._create_indices()
|
|
188
221
|
self._indices_to_dimensions()
|
|
189
222
|
|
|
223
|
+
# define tolerance based on the sovler tolerance
|
|
224
|
+
self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
|
|
225
|
+
|
|
226
|
+
self._reindex_time_dim()
|
|
227
|
+
|
|
228
|
+
if "survival" in self.observations:
|
|
229
|
+
if "subject_count" not in self.observations.coords:
|
|
230
|
+
self.observations = self.observations.assign_coords(
|
|
231
|
+
subject_count=("id", self.observations["survival"].isel(time=0).values, )
|
|
232
|
+
)
|
|
233
|
+
self.observations = self._data.prepare_survival_data_for_conditional_binomial(
|
|
234
|
+
observations=self.observations
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
if "exposure" in self.observations:
|
|
238
|
+
self.config.data_structure.exposure.observed=False
|
|
239
|
+
|
|
190
240
|
def _create_indices(self):
    """Subclass hook for adding indices to sim.indices and sim.observations.

    The base implementation does nothing.
    """
    return None
|
|
@@ -223,30 +273,38 @@ class GutsBase(SimulationBase):
|
|
|
223
273
|
"is calculated without a dense time resolution, the estimates can be biased!"
|
|
224
274
|
))
|
|
225
275
|
|
|
226
|
-
|
|
227
|
-
|
|
228
276
|
def recompute_posterior(self):
|
|
229
277
|
"""This function interpolates the posterior with a given resolution
|
|
230
278
|
posterior_predictions calculate proper survival predictions for the
|
|
231
279
|
posterior.
|
|
280
|
+
|
|
281
|
+
It also makes sure that the new interpolation does not include fewer values
|
|
282
|
+
than the original dataset
|
|
232
283
|
"""
|
|
233
284
|
|
|
234
|
-
|
|
235
|
-
self.results_interpolation[0] = float(self.observations["time"].min())
|
|
236
|
-
|
|
237
|
-
if np.isnan(self.results_interpolation[1]):
|
|
238
|
-
self.results_interpolation[1] = float(self.observations["time"].max())
|
|
285
|
+
ri = self.results_interpolation
|
|
239
286
|
|
|
240
287
|
# generate high resolution posterior predictions
|
|
241
288
|
if self.results_interpolation is not None:
|
|
242
289
|
time_interpolate = np.linspace(
|
|
243
|
-
start=self.
|
|
244
|
-
stop=self.
|
|
290
|
+
start=float(self.observations["time"].min()) if np.isnan(ri[0]) else ri[0],
|
|
291
|
+
stop=float(self.observations["time"].max()) if np.isnan(ri[0]) else ri[1],
|
|
245
292
|
num=self.results_interpolation[2]
|
|
246
293
|
)
|
|
247
|
-
|
|
248
|
-
|
|
294
|
+
|
|
295
|
+
# combine original coordinates and interpolation. This
|
|
296
|
+
# a) helps error checking during posterior predictions.
|
|
297
|
+
# b) makes sure that the original time vector is retained, which may be
|
|
298
|
+
# relevant for the simulation success (e.g. IT model)
|
|
299
|
+
obs = self.observations.reindex(
|
|
300
|
+
time=np.unique(np.concatenate(
|
|
301
|
+
[time_interpolate, self.observations["time"]]
|
|
302
|
+
)),
|
|
249
303
|
)
|
|
304
|
+
|
|
305
|
+
obs["survivors_before_t"] = obs.survivors_before_t.ffill(dim="time").astype(int)
|
|
306
|
+
obs["survivors_at_start"] = obs.survivors_at_start.ffill(dim="time").astype(int)
|
|
307
|
+
self.observations = obs
|
|
250
308
|
|
|
251
309
|
self.dispatch_constructor()
|
|
252
310
|
_ = self._prob.posterior_predictions(self, self.inferer.idata) # type: ignore
|
|
@@ -254,13 +312,13 @@ class GutsBase(SimulationBase):
|
|
|
254
312
|
self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
|
|
255
313
|
|
|
256
314
|
|
|
257
|
-
def prior_predictive_checks(self, **plot_kwargs):
    """Run the base-class prior predictive checks, then plot survival priors.

    All keyword arguments are forwarded unchanged to the parent
    implementation; afterwards the prior predictions of the ``survival``
    data variable are plotted.
    """
    super().prior_predictive_checks(**plot_kwargs)
    survival_only = ["survival"]
    self._plot.plot_prior_predictions(self, data_vars=survival_only)
|
|
261
319
|
|
|
262
|
-
def posterior_predictive_checks(self):
|
|
263
|
-
super().posterior_predictive_checks()
|
|
320
|
+
def posterior_predictive_checks(self, **plot_kwargs):
|
|
321
|
+
super().posterior_predictive_checks(**plot_kwargs)
|
|
264
322
|
|
|
265
323
|
self.recompute_posterior()
|
|
266
324
|
# TODO: Include posterior_predictive group once the survival predictions are correctly working
|
|
@@ -270,37 +328,93 @@ class GutsBase(SimulationBase):
|
|
|
270
328
|
def plot(self, results):
    """Hand the simulation results to the survival plotting routine."""
    plotter = self._plot
    plotter.plot_survival(self, results)
|
|
272
330
|
|
|
273
|
-
def copy(self):
|
|
274
|
-
with warnings.catch_warnings(action="ignore"):
|
|
275
|
-
sim_copy = type(self)(self.config)
|
|
276
|
-
sim_copy.observations = self.observations
|
|
277
|
-
sim_copy.model_parameters = self.model_parameters
|
|
278
|
-
if self.inferer is not None:
|
|
279
|
-
sim_copy.inferer = type(self.inferer)(self)
|
|
280
|
-
sim_copy.inferer.idata = self.inferer.idata
|
|
281
|
-
sim_copy.model = self.model
|
|
282
|
-
sim_copy.solver_post_processing = self.solver_post_processing
|
|
283
|
-
sim_copy.load_modules()
|
|
284
|
-
|
|
285
|
-
return sim_copy
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
@property
|
|
289
331
|
def predefined_scenarios(self):
    """Build one pulse-exposure scenario per coordinate of the exposure dimension.

    The exposure array must have exactly one dimension besides the batch and
    x dimensions of the simulation. For every coordinate ``c`` of that
    dimension a scenario is designed in which ``c`` receives concentration
    1.0 between start=0 and end=1 and all other coordinates receive 0.0.
    Scenarios run up to the larger of the last observed x-value and the
    ``ecx_estimates_times`` of the report class, with ``dt=1/24``.

    Returns
    -------
    dict
        Mapping ``"1day_exposure_{c}"`` -> result of
        ``design_exposure_scenario``.

    Raises
    ------
    ValueError
        If the exposure does not have exactly one dimension besides the batch
        and x dimensions.

    TODO: Fix timescale to observations (dt and pulse duration are hard-coded)
    TODO: Incorporate extra exposure patterns (constant, pulse_1day, pulse_2day)
    """
    x_dim = self.config.simulation.x_dimension
    batch_dim = self.config.simulation.batch_dimension

    # get the maximum possible time to provide exposure scenarios that are
    # definitely long enough; float() turns the 0-d DataArray into a number
    time_max = float(max(
        float(self.observations[x_dim].max()),
        *self.Report.ecx_estimates_times
    ))

    # the scenarios have the dimensions ID and TIME plus one exposure dimension
    standard_dimensions = (batch_dim, x_dim)
    extra_dimensions = [
        d for d in self.observations.exposure.dims
        if d not in standard_dimensions
    ]

    # zero extra dimensions previously crashed with an IndexError
    if len(extra_dimensions) != 1:
        raise ValueError(
            f"{type(self).__name__} can currently handle exactly one additional dimension for "
            f"the exposure beside {standard_dimensions}. You provided an exposure "
            f"array with the dimensions: {self.observations.exposure.dims}"
        )
    exposure_dimension = extra_dimensions[0]

    exposure_coordinates = self.observations.exposure[exposure_dimension].values

    scenarios = {}
    for coord in exposure_coordinates:
        # unit concentration for the current coordinate, zero for all others
        concentrations = np.where(exposure_coordinates == coord, 1.0, 0.0)

        exposure_dict = {
            c: ExposureDataDict(start=0, end=1, concentration=conc)
            for c, conc in zip(exposure_coordinates, concentrations)
        }

        scenarios[f"1day_exposure_{coord}"] = design_exposure_scenario(
            exposures=exposure_dict,
            t_max=time_max,
            dt=1/24,
            exposure_dimension=exposure_dimension,
        )

    return scenarios
|
|
389
|
+
|
|
390
|
+
@staticmethod
def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
    """Stack per-coordinate exposure tables into one ``exposure`` Dataset.

    Each DataFrame's index becomes the ``time`` dimension and its columns
    become the ``id`` coordinate; the keys of ``exposure_data`` become the
    coordinates of the new dimension ``dim``. The input DataFrames are not
    modified (the index is renamed on a new object via ``rename_axis``).

    Returns
    -------
    xr.Dataset
        Dataset holding the variable ``exposure`` with dimensions ordered
        ``("id", "time", dim)``.

    TODO: Currently no rect interpolation
    """
    # naming the index "time" is necessary to make all dimensions work out
    arrays = {
        key: df.rename_axis("time").to_xarray().to_dataarray(dim="id", name=key)
        for key, df in exposure_data.items()
    }

    exposure_array = xr.Dataset(arrays).to_dataarray(dim=dim, name="exposure")
    exposure_array = exposure_array.transpose("id", "time", ...)
    return xr.Dataset({"exposure": exposure_array})
|
|
406
|
+
|
|
407
|
+
@staticmethod
def _survival_data_to_xarray(survival_data: pd.DataFrame):
    """Convert a survival table into a Dataset with the variable ``survival``.

    The DataFrame's index becomes the ``time`` dimension and its columns
    become the ``id`` coordinate; the result is ordered ``("id", "time")``.
    The input DataFrame is not modified (the index is renamed on a new object
    via ``rename_axis``).
    """
    # TODO: survival name is currently not kept because the raw data is not transferred from the survival
    survival_array = (
        survival_data
        .rename_axis("time")
        .to_xarray()
        .to_dataarray(dim="id", name="survival")
        .transpose("id", "time", ...)
    )
    return xr.Dataset({"survival": survival_array})
|
|
416
|
+
|
|
417
|
+
|
|
304
418
|
def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
|
|
305
419
|
"""This method will take an existing coordinate of a dataset that has the same
|
|
306
420
|
coordinate has the batch dimension. It will then re-express the coordinate as a
|
|
@@ -367,6 +481,164 @@ class GutsBase(SimulationBase):
|
|
|
367
481
|
def initialize_from_script(self):
    """No-op in GutsBase; subclasses such as GutsSimulationConstantExposure override it."""
    return None
|
|
369
483
|
|
|
484
|
+
@property
def _model_class(self):
    """Resolve ``config.simulation.model_class`` to the object it names.

    The setting is a dotted path ``"package.module.Attribute"``: the module
    part is imported with ``importlib.import_module`` and the last component
    is looked up on it. Returns ``None`` if the setting is absent or empty.

    Raises
    ------
    GutsBaseError
        If the setting has no module part (e.g. ``"MyModel"``).
    """
    dotted_path = getattr(self.config.simulation, "model_class", None)
    if not dotted_path:
        return None

    module_name, sep, attr_name = dotted_path.rpartition(".")
    if not sep or not module_name or not attr_name:
        raise GutsBaseError(
            "config.simulation.model_class must be a dotted path of the form "
            f"'package.module.ClassName'; got {dotted_path!r}"
        )

    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
|
|
492
|
+
|
|
493
|
+
### API methods ###
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def point_estimate(
    self,
    estimate: Literal["mean", "map"] = "map",
    to: Literal["xarray", "dict"] = "xarray"
):
    """Return a point estimate of the posterior.

    If you want more control over the posterior, use the attribute
    ``sim.inferer.idata.posterior`` and summarize it or select from it using
    the arviz (https://python.arviz.org/en/stable/index.html) and xarray
    (https://docs.xarray.dev/en/stable/index.html) packages.

    Parameters
    ----------
    estimate : Literal["map", "mean"]
        Point estimate to return.

        - map: the single posterior draw with the highest total
          log-likelihood (summed over all observations and all variables of
          ``idata.log_likelihood``). Being one joint draw, it keeps the
          correlation structure of the posterior. Only the log-likelihood is
          used (no prior term), so this is the maximum-likelihood draw among
          the posterior samples, an approximation of the MAP.
        - mean: the average of all marginal parameter distributions.

    to : Literal["xarray", "dict"]
        Representation of the result. ``"dict"`` maps each posterior variable
        to its values (numpy) and can be used to insert parameters in the
        ``.evaluate()`` method. ``"xarray"`` (default) returns the summarized
        posterior Dataset.

    Raises
    ------
    GutsBaseError
        If ``estimate`` or ``to`` is not one of the supported values.

    Example
    -------

    >>> sim.point_estimate(to='dict')
    """
    idata = self.inferer.idata
    sample_dims = ("chain", "draw")

    if estimate == "mean":
        best_estimate = idata.posterior.mean(sample_dims)

    elif estimate == "map":
        log_likelihood = idata.log_likelihood
        # Reduce every variable over its own non-sample dims first (each
        # variable is only reduced over the dims it has), then add up the
        # variables -> one total log-likelihood per (chain, draw). Using all
        # non-sample dims instead of a fixed ("id", "time") also supports
        # log-likelihoods with additional or fewer observation dimensions.
        observation_dims = [d for d in log_likelihood.dims if d not in sample_dims]
        loglik = (
            log_likelihood
            .sum(observation_dims)
            .to_dataarray(dim="variable")
            .sum("variable")
        )

        sample_max_loglik = loglik.argmax(dim=sample_dims)
        best_estimate = idata.posterior.sel(sample_max_loglik)
    else:
        raise GutsBaseError(
            f"Estimate '{estimate}' not implemented. Choose one of ['mean', 'map']"
        )

    if to == "xarray":
        return best_estimate

    elif to == "dict":
        return {k: v.values for k, v in best_estimate.items()}

    else:
        raise GutsBaseError(
            f"{type(self).__name__}.point_estimate() supports only return types "
            f"to=['xarray', 'dict']. You used {to=}"
        )
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def evaluate(
    self,
    parameters: "Optional[Mapping[str, float | NumericArray | Sequence[float]]]" = None,
    y0: "Optional[Mapping[str, float | NumericArray | Sequence[float]]]" = None,
    x_in: "Optional[Mapping[str, float | NumericArray | Sequence[float]]]" = None,
):
    """Evaluate the model along the coordinates of the observations.

    The mappings passed to the arguments only overwrite the existing defaults,
    which makes the usage very simple. Omitted arguments (``None``) are passed
    to ``dispatch`` as empty dicts, i.e. nothing is overwritten.

    Note that the first run of .evaluate() after calling .dispatch_constructor()
    takes a little longer, because the model and solver are jit-compiled to JAX
    for highly efficient computations.

    Parameters
    ----------
    parameters : Mapping[str, float | Sequence[float]], optional
        Model parameters that should be changed for dispatch (forwarded as
        ``theta``). Unspecified model parameters assume the default values
        specified under config.model_parameters.NAME.value
    y0 : Mapping[str, float | Sequence[float]], optional
        Initial values that should be changed for dispatch.
    x_in : Mapping[str, float | Sequence[float]], optional
        Model input values that should be changed for dispatch.

    Returns
    -------
    The ``results`` attribute of the evaluator after it has been run.

    Example
    -------

    >>> sim.dispatch_constructor()  # necessary if the sim object has been modified
    >>> # evaluate setting the background mortality to zero
    >>> sim.evaluate(parameters={'hb': 0.0})
    """
    # fresh dicts per call instead of shared mutable default arguments
    evaluator = self.dispatch(
        theta={} if parameters is None else parameters,
        x_in={} if x_in is None else x_in,
        y0={} if y0 is None else y0,
    )
    evaluator()
    return evaluator.results
|
|
594
|
+
|
|
595
|
+
def load_exposure_scenario(
    self,
    data: str|Dict[str,pd.DataFrame],
    sheet_name_prefix: str = "",
    rect_interpolate=False
):
    """Replace the exposure of the simulation with new exposure data.

    Parameters
    ----------
    data : str | Dict[str, pd.DataFrame]
        Either a path that is read with ``read_excel_file`` (time converted to
        ``self.unit_time``), or a mapping from a coordinate of the extra
        exposure dimension to a DataFrame with time as index and one column
        per id.
    sheet_name_prefix : str
        Forwarded to ``read_excel_file`` when ``data`` is a path.
    rect_interpolate : bool
        Currently not used by this method.

    Side effects
    ------------
    - stores a deep copy of the current observations in ``self._obs_backup``
    - replaces ``self.observations`` by the new exposure combined with the
      existing survival observations, restricted to the ids of the new
      exposure and to times not later than the last exposure time
    - sets ``config.simulation.x_in`` to ``["exposure=exposure"]`` and
      re-parses ``model_parameters["x_in"]`` (forward-filled along time) and
      ``model_parameters["y0"]``
    - calls ``dispatch_constructor()``

    Raises
    ------
    GutsBaseError
        If ``config.data_structure.exposure.dimensions`` has no dimension
        besides the x and batch dimensions.
    """
    if isinstance(data, str):
        # NOTE(review): read_excel_file is not imported in the visible part of
        # this module; confirm it is in scope, otherwise this branch raises
        # a NameError.
        _data, _time_unit = read_excel_file(
            path=data,
            sheet_name_prefix=sheet_name_prefix,
            convert_time_to=self.unit_time
        )
    else:
        _data = data

    # the exposure needs one dimension beside the x and batch dimensions;
    # validate before touching any state of the simulation
    standard_dims = (
        self.config.simulation.x_dimension,
        self.config.simulation.batch_dimension,
    )
    exposure_dim = [
        d for d in self.config.data_structure.exposure.dimensions
        if d not in standard_dims
    ]
    if not exposure_dim:
        raise GutsBaseError(
            "config.data_structure.exposure.dimensions must contain a dimension "
            f"besides {standard_dims}; got "
            f"{self.config.data_structure.exposure.dimensions}"
        )

    self._obs_backup = self.observations.copy(deep=True)

    # convert the exposure tables to an xarray Dataset
    exposure = self._exposure_data_to_xarray(
        exposure_data=_data,
        dim=exposure_dim[0]
    )

    # combine with observations
    new_obs = xr.combine_by_coords([
        exposure,
        self.observations.survival
    ]).sel(id=exposure.id)

    # keep only time points that are covered by the new exposure
    time_values = new_obs.time.values
    last_exposure_time = exposure.time.max().values
    self.observations = new_obs.sel(time=time_values[time_values <= last_exposure_time])

    self.config.simulation.x_in = ["exposure=exposure"]
    self.model_parameters["x_in"] = self.parse_input("x_in", exposure).ffill("time")  # type: ignore
    self.model_parameters["y0"] = self.parse_input("y0", drop_dims=["time"])

    self.dispatch_constructor()
|
|
637
|
+
|
|
638
|
+
def export(self, directory: Optional[str] = None):
    """Export the simulation through the base-class ``export``.

    Beforehand, ``config.simulation.skip_data_processing`` is reset to
    ``False`` (the change stays on ``self.config``), so the exported
    configuration does not request skipping data processing.
    """
    sim_settings = self.config.simulation
    sim_settings.skip_data_processing = False
    super().export(directory=directory)
|
|
641
|
+
|
|
370
642
|
class GutsSimulationConstantExposure(GutsBase):
|
|
371
643
|
t_max = 10
|
|
372
644
|
def initialize_from_script(self):
|
|
@@ -387,7 +659,7 @@ class GutsSimulationConstantExposure(GutsBase):
|
|
|
387
659
|
self.config.model_parameters.z = Param(value=0.2, free=True)
|
|
388
660
|
|
|
389
661
|
self.model_parameters["parameters"] = self.config.model_parameters.value_dict
|
|
390
|
-
self.config.simulation.model = "
|
|
662
|
+
self.config.simulation.model = "guts_constant_exposure"
|
|
391
663
|
|
|
392
664
|
self.coordinates["time"] = np.linspace(0,self.t_max)
|
|
393
665
|
|
|
@@ -397,7 +669,7 @@ class GutsSimulationConstantExposure(GutsBase):
|
|
|
397
669
|
# =======================
|
|
398
670
|
|
|
399
671
|
self.coordinates["time"] = np.array([0,self.t_max])
|
|
400
|
-
self.config.simulation.model = "
|
|
672
|
+
self.config.simulation.model = "guts_constant_exposure"
|
|
401
673
|
|
|
402
674
|
self.solver = JaxSolver
|
|
403
675
|
|
|
@@ -438,7 +710,7 @@ class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
|
|
|
438
710
|
self.config.model_parameters.remove("C_0")
|
|
439
711
|
|
|
440
712
|
self.model_parameters["parameters"] = self.config.model_parameters.value_dict
|
|
441
|
-
self.config.simulation.solver_post_processing = "
|
|
713
|
+
self.config.simulation.solver_post_processing = "red_sd_post_processing"
|
|
442
714
|
self.config.simulation.model = "guts_variable_exposure"
|
|
443
715
|
|
|
444
716
|
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import arviz as az
|
|
3
|
+
from guts_base.sim import GutsBase
|
|
4
|
+
|
|
5
|
+
def construct_sim_from_config(
    scenario: str,
    simulation_class: type,
    output_path=None
) -> "GutsBase":
    """Helper function to construct simulations for debugging.

    The settings file is read from ``scenarios/{scenario}/settings.cfg``,
    relative to the current working directory.

    Parameters
    ----------
    scenario : str
        Name of the scenario directory below ``scenarios/``.
    simulation_class : type
        Class that is instantiated with the settings-file path.
    output_path : str | os.PathLike | None
        If given, the config's output path is set to
        ``output_path/<case_study.name>/results/<case_study.scenario>``.
        If ``None``, the config's scenario name is set to ``"debug_test"``.

    Returns
    -------
    The simulation after ``setup()`` has been called.
    """
    import pathlib

    sim = simulation_class(f"scenarios/{scenario}/settings.cfg")

    # this sets a different output directory
    if output_path is not None:
        # accept plain strings as well as Path objects
        p = (
            pathlib.Path(output_path)
            / sim.config.case_study.name
            / "results"
            / sim.config.case_study.scenario
        )
        sim.config.case_study.output_path = str(p)
    else:
        sim.config.case_study.scenario = "debug_test"
    sim.setup()
    return sim
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_idata(sim: "GutsBase", idata_file: str) -> "GutsBase":
    """Attach a numpyro inferer to ``sim`` and load its InferenceData.

    If ``idata_file`` exists it is read with ``arviz.from_netcdf``. Otherwise
    ``sim.inferer.idata`` is set to ``None`` and a warning is emitted, so the
    missing file does not go unnoticed.

    Returns
    -------
    The same ``sim`` object.
    """
    import warnings

    sim.set_inferer("numpyro")

    if os.path.exists(idata_file):
        sim.inferer.idata = az.from_netcdf(idata_file)
    else:
        warnings.warn(
            f"InferenceData file '{idata_file}' does not exist; "
            "sim.inferer.idata is set to None.",
            stacklevel=2,
        )
        sim.inferer.idata = None

    return sim
|