guts-base 0.8.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


guts_base/sim/base.py CHANGED
@@ -2,28 +2,30 @@ import os
  import glob
  from functools import partial
  from copy import deepcopy
+ import importlib
  import warnings
  import numpy as np
  import xarray as xr
  from diffrax import Dopri5
- from typing import Literal, Optional, List, Dict
+ from typing import Literal, Optional, List, Dict, Mapping, Sequence, Tuple
  import tempfile
  import pandas as pd

  from pymob import SimulationBase
- from pymob.sim.config import DataVariable, Param, string_to_list
+ from pymob.sim.config import DataVariable, Param, string_to_list, NumericArray

  from pymob.solvers import JaxSolver
  from pymob.solvers.base import rect_interpolation
- from pymob.sim.config import ParameterDict
  from expyDB.intervention_model import (
      Treatment, Timeseries, select, from_expydb
  )

+
+ from guts_base.sim.utils import GutsBaseError
  from guts_base import mod
  from guts_base.data import (
      to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
-     create_database_and_import_data_main, design_exposure_scenario
+     create_database_and_import_data_main, design_exposure_scenario, ExposureDataDict
  )
  from guts_base.sim.report import GutsReport

@@ -35,48 +37,77 @@ class GutsBase(SimulationBase):
      1. check if necessary entries are made in the configuration, otherwise add defaults
      2. read data or take from input
      3. process data (add dimensions, or add indices)
+     4. Prepare model input
      """
      solver = JaxSolver
      Report = GutsReport
+     results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
+     _skip_data_processing: bool = False
+
+     def initialize(self, input: Optional[Dict] = None):
+         """Initialization goes through a couple of steps:
+
+         1. Configuration: This makes case-study specific changes to the configuration
+            file or sets state variables that are relevant for the simulation.
+            TODO: Ideally everything that is configurable ends up in the config so it
+            can be serialized
+
+         2. Import data: This method consists of submethods that can be adapted or
+            overwritten in subclass methods:
+            - .read_data
+            - .save_observations
+            - .process_data
+            process_data itself utilizes the submethods _create_indices and
+            _indices_to_dimensions, which are empty methods by default, but can be used
+            in subclasses if needed
+
+         3. Initialize the simulation input (parameters, y0, x_in). This can
+
+         By splitting the simulation init method into these three steps, modifications
+         of the initialize method allow for higher granularity in subclasses.
+         """
+
+         # 1. Configuration
+         self.configure_case_study()
+
+         # 2. Import data
+         self.observations = self.read_data()
+         # FIXME: Saving observations here is not intuitive. If I export a simulation,
+         # I want to use the last used state, not some obscure intermediate state
+         # self.save_observations(filename="observations.nc", directory=self.output_path, force=True)
+         if not self._skip_data_processing:
+             self.process_data()
+
+         # 3. prepare y0 and x_in
+         self.prepare_simulation_input()
+
+     def configure_case_study(self):
+         """Modify configuration file or set state variables
+         TODO: This should only modify the configuration file, so that changes
+         are transparent.
+         """
+         if self._model_class is not None:
+             self.model = self._model_class._rhs_jax
+             self.solver_post_processing = self._model_class._solver_post_processing

-     def initialize(self, input: Dict = None):
          self.ecx_mode: Literal["mean", "draws"] = "mean"

          self.unit_time: Literal["day", "hour", "minute", "second"] = "day"
          if hasattr(self.config.simulation, "unit_time"):
              self.unit_time = self.config.simulation.unit_time # type: ignore

-         self.results_interpolation: Optional[List[float|int]] = [np.nan, np.nan, 100]
-         if hasattr(self.config.simulation, "results_interpolation"):
-             self.results_interpolation = string_to_list(self.config.simulation.results_interpolation)
-             self.results_interpolation[0] = float(self.results_interpolation[0])
-             self.results_interpolation[1] = float(self.results_interpolation[1])
-             self.results_interpolation[2] = int(self.results_interpolation[2])
-
-         if "observations" in input:
-             self.observations = input["observations"]
-         else:
-             self.observations = self.read_data()
-             self.process_data()
-
-         # define tolerance based on the solver tolerance
-         self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
-
-         self._reindex_time_dim()
+         if hasattr(self.config.simulation, "skip_data_processing"):
+             self._skip_data_processing = bool(self.config.simulation.skip_data_processing) # type: ignore

-         if "survival" in self.observations:
-             if "subject_count" not in self.observations.coords:
-                 self.observations = self.observations.assign_coords(
-                     subject_count=("id", self.observations["survival"].isel(time=0).values, )
-                 )
-             self.observations = self._data.prepare_survival_data_for_conditional_binomial(
-                 observations=self.observations
+         if hasattr(self.config.simulation, "results_interpolation"):
+             results_interpolation_string = string_to_list(self.config.simulation.results_interpolation)
+             self.results_interpolation = (
+                 float(results_interpolation_string[0]),
+                 float(results_interpolation_string[1]),
+                 int(results_interpolation_string[2])
              )

-         if "exposure" in self.observations:
-             self.config.data_structure.exposure.observed=False
-
-         # prepare y0 and x_in
+     def prepare_simulation_input(self):
          x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
          y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])

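The docstring above lays out a three-step initialization (configure the case study, import data, prepare the simulation input) that is meant to be customized per case study. Below is a minimal sketch of how a subclass might hook into these steps; only the hook names come from the diff, while the subclass name, the data file, and the unit override are illustrative assumptions.

    # Hypothetical subclass sketch of the new three-step initialize() flow.
    # Only the hook names (configure_case_study, read_data, _create_indices,
    # process_data, prepare_simulation_input) appear in the diff above; the
    # rest is made up for illustration.
    import xarray as xr
    from guts_base.sim import GutsBase

    class MyCaseStudy(GutsBase):
        def configure_case_study(self):
            # keep the base handling (model_class, unit_time, results_interpolation, ...)
            super().configure_case_study()
            self.unit_time = "hour"

        def read_data(self) -> xr.Dataset:
            # load observations from wherever the case study stores them
            return xr.load_dataset("observations.nc")

        def _create_indices(self):
            # optional hook called by process_data(); empty in the base class
            pass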
@@ -189,6 +220,23 @@ class GutsBase(SimulationBase):
          self._create_indices()
          self._indices_to_dimensions()

+         # define tolerance based on the solver tolerance
+         self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
+
+         self._reindex_time_dim()
+
+         if "survival" in self.observations:
+             if "subject_count" not in self.observations.coords:
+                 self.observations = self.observations.assign_coords(
+                     subject_count=("id", self.observations["survival"].isel(time=0).values, )
+                 )
+             self.observations = self._data.prepare_survival_data_for_conditional_binomial(
+                 observations=self.observations
+             )
+
+         if "exposure" in self.observations:
+             self.config.data_structure.exposure.observed=False
+
      def _create_indices(self):
          """Use if indices should be added to sim.indices and sim.observations"""
          pass
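To make the new survival handling in process_data() concrete, here is a small self-contained illustration of the subject_count coordinate: it is taken from the survival variable at the first time point. The toy numbers are invented.

    # Toy illustration of the subject_count coordinate added in process_data().
    import numpy as np
    import xarray as xr

    obs = xr.Dataset(
        {"survival": (("id", "time"), np.array([[10, 9, 7], [10, 10, 8]]))},
        coords={"id": [0, 1], "time": [0.0, 1.0, 2.0]},
    )

    if "subject_count" not in obs.coords:
        obs = obs.assign_coords(subject_count=("id", obs["survival"].isel(time=0).values))

    print(obs.subject_count.values)  # [10 10]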
@@ -225,8 +273,6 @@ class GutsBase(SimulationBase):
              "is calculated without a dense time resolution, the estimates can be biased!"
          ))

-
-
      def recompute_posterior(self):
          """This function interpolates the posterior with a given resolution
          posterior_predictions calculate proper survival predictions for the
@@ -236,17 +282,13 @@ class GutsBase(SimulationBase):
          than the original dataset
          """

-         if np.isnan(self.results_interpolation[0]):
-             self.results_interpolation[0] = float(self.observations["time"].min())
-
-         if np.isnan(self.results_interpolation[1]):
-             self.results_interpolation[1] = float(self.observations["time"].max())
+         ri = self.results_interpolation

          # generate high resolution posterior predictions
          if self.results_interpolation is not None:
              time_interpolate = np.linspace(
-                 start=self.results_interpolation[0],
-                 stop=self.results_interpolation[1],
+                 start=float(self.observations["time"].min()) if np.isnan(ri[0]) else ri[0],
+                 stop=float(self.observations["time"].max()) if np.isnan(ri[1]) else ri[1],
                  num=self.results_interpolation[2]
              )

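The class-level default results_interpolation = (np.nan, np.nan, 100) now reads as "span the observed time range with 100 points". A standalone sketch of the fallback applied above, with invented values:

    # Sketch of the results_interpolation fallback: NaN start/stop fall back to
    # the observed time range; the third entry is the number of points.
    import numpy as np

    ri = (np.nan, np.nan, 100)                 # (start, stop, num) as stored on the class
    obs_time = np.array([0.0, 1.0, 2.0, 4.0])  # stand-in for observations["time"]

    time_interpolate = np.linspace(
        start=float(obs_time.min()) if np.isnan(ri[0]) else ri[0],
        stop=float(obs_time.max()) if np.isnan(ri[1]) else ri[1],
        num=ri[2],
    )
    print(time_interpolate[0], time_interpolate[-1], time_interpolate.size)  # 0.0 4.0 100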
@@ -254,11 +296,15 @@ class GutsBase(SimulationBase):
              # a) helps error checking during posterior predictions.
              # b) makes sure that the original time vector is retained, which may be
              # relevant for the simulation success (e.g. IT model)
-             self.observations = self.observations.reindex(
+             obs = self.observations.reindex(
                  time=np.unique(np.concatenate(
                      [time_interpolate, self.observations["time"]]
-                 ))
+                 )),
              )
+
+             obs["survivors_before_t"] = obs.survivors_before_t.ffill(dim="time").astype(int)
+             obs["survivors_at_start"] = obs.survivors_at_start.ffill(dim="time").astype(int)
+             self.observations = obs

          self.dispatch_constructor()
          _ = self._prob.posterior_predictions(self, self.inferer.idata) # type: ignore
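The forward-fill added after the reindex keeps the integer survivor counters defined at the newly inserted time points instead of leaving NaNs. Roughly, in isolation (toy data; xarray's ffill needs bottleneck or numbagg installed):

    # Toy example of the ffill step on a reindexed counter variable.
    import numpy as np
    import xarray as xr

    counts = xr.DataArray([10, 8, 7], coords={"time": [0.0, 1.0, 2.0]}, dims="time")
    dense = counts.reindex(time=np.linspace(0.0, 2.0, 5))   # inserts NaNs at 0.5 and 1.5
    filled = dense.ffill(dim="time").astype(int)
    print(filled.values)  # [10 10  8  8  7]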
@@ -266,13 +312,13 @@ class GutsBase(SimulationBase):
          self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")


-     def prior_predictive_checks(self):
-         super().prior_predictive_checks()
+     def prior_predictive_checks(self, **plot_kwargs):
+         super().prior_predictive_checks(**plot_kwargs)

          self._plot.plot_prior_predictions(self, data_vars=["survival"])

-     def posterior_predictive_checks(self):
-         super().posterior_predictive_checks()
+     def posterior_predictive_checks(self, **plot_kwargs):
+         super().posterior_predictive_checks(**plot_kwargs)

          self.recompute_posterior()
          # TODO: Include posterior_predictive group once the survival predictions are correctly working
@@ -282,28 +328,6 @@ class GutsBase(SimulationBase):
      def plot(self, results):
          self._plot.plot_survival(self, results)

-     def copy(self):
-         with warnings.catch_warnings(action="ignore"):
-             sim_copy = type(self)(self.config.copy(deep=True))
-             sim_copy.observations = self.observations.copy(deep=True)
-
-             # must copy individual parts of the dict due to the on_update method
-             model_parameters = {k: deepcopy(v) for k, v in self.model_parameters.items()}
-
-             # TODO: Refactor this once the parameterize method is removed.
-             sim_copy.parameterize = partial(sim_copy.parameterize, model_parameters=model_parameters)
-             sim_copy._model_parameters = ParameterDict(model_parameters, callback=sim_copy._on_params_updated)
-
-             sim_copy.load_modules()
-             if hasattr(self, "inferer"):
-                 sim_copy.inferer = type(self.inferer)(sim_copy)
-                 # idata uses deepcopy by default
-                 sim_copy.inferer.idata = self.inferer.idata.copy()
-             sim_copy.model = self.model
-             sim_copy.solver_post_processing = self.solver_post_processing
-
-         return sim_copy
-
      def predefined_scenarios(self):
          """
          TODO: Fix timescale to observations
@@ -343,10 +367,10 @@ class GutsBase(SimulationBase):

          scenarios = {}
          for coord in exposure_coordinates:
-             concentrations = np.where(coord == exposure_coordinates, 1, 0)
+             concentrations = np.where(coord == exposure_coordinates, 1.0, 0.0)

              exposure_dict = {
-                 coord: dict(start=0.0, end=1.0, concentration=conc)
+                 coord: ExposureDataDict(start=0, end=1, concentration=conc)
                  for coord, conc in zip(exposure_coordinates, concentrations)
              }

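The scenario construction now builds typed ExposureDataDict entries instead of plain dicts. A sketch of one generated scenario entry; the substance names are invented, the call signature follows the diff:

    # Illustrative scenario entry as built in predefined_scenarios(): the selected
    # exposure coordinate gets a unit pulse between t=0 and t=1, all others are zero.
    import numpy as np
    from guts_base.data import ExposureDataDict

    exposure_coordinates = np.array(["substance_a", "substance_b"])
    coord = "substance_a"
    concentrations = np.where(coord == exposure_coordinates, 1.0, 0.0)

    exposure_dict = {
        c: ExposureDataDict(start=0, end=1, concentration=conc)
        for c, conc in zip(exposure_coordinates, concentrations)
    }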
@@ -363,6 +387,34 @@ class GutsBase(SimulationBase):

          return scenarios

+     @staticmethod
+     def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
+         """
+         TODO: Currently no rect interpolation
+         """
+         arrays = {}
+         for key, df in exposure_data.items():
+             # this override is necessary to make all dimensions work out
+             df.index.name = "time"
+             arrays.update({
+                 key: df.to_xarray().to_dataarray(dim="id", name=key)
+             })
+
+         exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
+         exposure_array = exposure_array.transpose("id", "time", ...)
+         return xr.Dataset({"exposure": exposure_array})
+
+     @staticmethod
+     def _survival_data_to_xarray(survival_data: pd.DataFrame):
+         # TODO: survival name is currently not kept because the raw data is not transferred from the survival
+         survival_data.index.name = "time"
+
+         survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
+         survival_array = survival_array.transpose("id", "time", ...)
+         arrays = {"survival": survival_array}
+         return xr.Dataset(arrays)
+
+
      def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
          """This method will take an existing coordinate of a dataset that has the same
          coordinate as the batch dimension. It will then re-express the coordinate as a
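Aside on the new _exposure_data_to_xarray helper added above: it expects a mapping of pandas DataFrames indexed by time, with one column per id, and stacks the dict keys into an extra exposure dimension. A minimal standalone sketch of the same conversion, with invented substance and treatment names (it mirrors the to_xarray/to_dataarray calls used in the diff, so it assumes a reasonably recent xarray):

    # Minimal illustration of the DataFrame -> xarray conversion performed by
    # _exposure_data_to_xarray(): the index becomes "time", columns become "id",
    # and the dict keys become the extra exposure dimension. Names are made up.
    import pandas as pd
    import xarray as xr

    exposure_data = {
        "substance_a": pd.DataFrame({"treatment_1": [1.0, 1.0, 0.0]}, index=[0.0, 1.0, 2.0]),
        "substance_b": pd.DataFrame({"treatment_1": [0.0, 0.5, 0.5]}, index=[0.0, 1.0, 2.0]),
    }

    arrays = {}
    for key, df in exposure_data.items():
        df.index.name = "time"
        arrays[key] = df.to_xarray().to_dataarray(dim="id", name=key)

    exposure_array = xr.Dataset(arrays).to_array(dim="substance", name="exposure")
    exposure = xr.Dataset({"exposure": exposure_array.transpose("id", "time", ...)})
    print(exposure["exposure"].dims)  # ('id', 'time', 'substance')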
@@ -429,6 +481,164 @@ class GutsBase(SimulationBase):
      def initialize_from_script(self):
          pass

+     @property
+     def _model_class(self):
+         if hasattr(self.config.simulation, "model_class"):
+             module, attr = self.config.simulation.model_class.rsplit(".", 1)
+             _module = importlib.import_module(module)
+             return getattr(_module, attr)
+         else:
+             return None
+
+     ### API methods ###
+
+
+     def point_estimate(
+         self,
+         estimate: Literal["mean", "map"] = "map",
+         to: Literal["xarray", "dict"] = "xarray"
+     ):
+         """Returns a point estimate of the posterior. If you want more control over the posterior,
+         use the attribute sim.inferer.idata.posterior and summarize it or select from it
+         using the arviz (https://python.arviz.org/en/stable/index.html) and the
+         xarray (https://docs.xarray.dev/en/stable/index.html) packages.
+
+         Parameters
+         ----------
+
+         estimate : Literal["map", "mean"]
+             Point estimate to return.
+             - map: Maximum a Posteriori. The sample that has the highest posterior probability.
+               This sample considers the correlation structure of the posterior.
+             - mean: The average of all marginal parameter distributions.
+
+         to : Literal["xarray", "dict"]
+             Specifies the representation to transform the summarized data to. dict can
+             be used to insert parameters into the .evaluate() method, while xarray is the
+             standard view. Defaults to xarray.
+
+         Example
+         -------
+
+         >>> sim.point_estimate(to='dict')
+         """
+         if estimate == "mean":
+             best_estimate = self.inferer.idata.posterior.mean(("chain", "draw"))
+
+         elif estimate == "map":
+             loglik = self.inferer.idata.log_likelihood\
+                 .sum(["id", "time"])\
+                 .to_array().sum("variable")
+
+             sample_max_loglik = loglik.argmax(dim=("chain", "draw"))
+             best_estimate = self.inferer.idata.posterior.sel(sample_max_loglik)
+         else:
+             raise GutsBaseError(
+                 f"Estimate '{estimate}' not implemented. Choose one of ['mean', 'map']"
+             )
+
+
+         if to == "xarray":
+             return best_estimate
+
+         elif to == "dict":
+             return {k: v.values for k, v in best_estimate.items()}
+
+         else:
+             raise GutsBaseError(
+                 "GutsBase.point_estimate() supports only return types to=['xarray', 'dict']. " +
+                 f"You used {to=}"
+             )
+
+
+     def evaluate(
+         self,
+         parameters: Mapping[str, float|NumericArray|Sequence[float]] = {},
+         y0: Mapping[str, float|NumericArray|Sequence[float]] = {},
+         x_in: Mapping[str, float|NumericArray|Sequence[float]] = {},
+     ):
+         """Evaluates the model along the coordinates of the observations with given
+         parameters, x_in, and y0. The dictionaries passed to the function arguments
+         only overwrite the existing default parameters, which makes the usage very simple.
+
+         Note that the first run of .evaluate() after calling .dispatch_constructor()
+         takes a little longer, because the model and solver are jit-compiled with JAX for
+         highly efficient computations.
+
+         Parameters
+         ----------
+
+         parameters : Dict[float|Sequence[float]]
+             Dictionary of model parameters that should be changed for dispatch.
+             Unspecified model parameters will assume the default values,
+             specified under config.model_parameters.NAME.value
+
+         y0 : Dict[float|Sequence[float]]
+             Dictionary of initial values that should be changed for dispatch.
+
+         x_in : Dict[float|Sequence[float]]
+             Dictionary of model input values that should be changed for dispatch.
+
+
+         Example
+         -------
+
+         >>> sim.dispatch_constructor() # necessary if the sim object has been modified
+         >>> # evaluate setting the background mortality to zero
+         >>> sim.evaluate(parameters={'hb': 0.0})
+
+         """
+         evaluator = self.dispatch(theta=parameters, x_in=x_in, y0=y0)
+         evaluator()
+         return evaluator.results
+
+     def load_exposure_scenario(
+         self,
+         data: str|Dict[str,pd.DataFrame],
+         sheet_name_prefix: str = "",
+         rect_interpolate=False
+
+     ):
+
+         if isinstance(data, str):
+             _data, time_unit = read_excel_file(
+                 path=data,
+                 sheet_name_prefix=sheet_name_prefix,
+                 convert_time_to=self.unit_time
+             )
+         else:
+             _data = data
+
+
+         self._obs_backup = self.observations.copy(deep=True)
+
+         # read exposure array from file
+         exposure_dim = [
+             d for d in self.config.data_structure.exposure.dimensions
+             if d not in (self.config.simulation.x_dimension, self.config.simulation.batch_dimension)
+         ]
+         exposure = self._exposure_data_to_xarray(
+             exposure_data=_data,
+             dim=exposure_dim[0]
+         )
+
+         # combine with observations
+         new_obs = xr.combine_by_coords([
+             exposure,
+             self.observations.survival
+         ]).sel(id=exposure.id)
+
+         self.observations = new_obs.sel(time=[t for t in new_obs.time if t <= exposure.time.max()])
+         self.config.simulation.x_in = ["exposure=exposure"]
+         self.model_parameters["x_in"] = self.parse_input("x_in", exposure).ffill("time") # type: ignore
+         self.model_parameters["y0"] = self.parse_input("y0", drop_dims=["time"])
+
+         self.dispatch_constructor()
+
+     def export(self, directory: Optional[str] = None):
+         self.config.simulation.skip_data_processing = False
+         super().export(directory=directory)
+
  class GutsSimulationConstantExposure(GutsBase):
      t_max = 10
      def initialize_from_script(self):
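A hedged usage sketch of the new public API added in this hunk (point_estimate, evaluate, load_exposure_scenario). It assumes an already configured and fitted simulation object sim with inference results in sim.inferer.idata; the substance and treatment names and the exposure table are placeholders:

    # Hypothetical end-to-end usage of the new API surface.
    import pandas as pd

    # point estimate of the posterior, as an xarray Dataset or a plain dict
    theta_map = sim.point_estimate(estimate="map", to="dict")

    # evaluate the model with these parameters
    # (the first call after dispatch_constructor jit-compiles model and solver)
    sim.dispatch_constructor()
    results = sim.evaluate(parameters=theta_map)

    # swap in a different exposure scenario; a dict of time-indexed DataFrames
    # (one column per id) also works instead of an Excel file path
    exposure = {
        "substance_a": pd.DataFrame({"treatment_1": [1.0, 1.0, 0.0]}, index=[0.0, 1.0, 2.0]),
    }
    sim.load_exposure_scenario(data=exposure)
    results_scenario = sim.evaluate(parameters=theta_map)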
@@ -449,7 +659,7 @@ class GutsSimulationConstantExposure(GutsBase):
          self.config.model_parameters.z = Param(value=0.2, free=True)

          self.model_parameters["parameters"] = self.config.model_parameters.value_dict
-         self.config.simulation.model = "guts_jax"
+         self.config.simulation.model = "guts_constant_exposure"

          self.coordinates["time"] = np.linspace(0,self.t_max)

@@ -459,7 +669,7 @@ class GutsSimulationConstantExposure(GutsBase):
          # =======================

          self.coordinates["time"] = np.array([0,self.t_max])
-         self.config.simulation.model = "guts_jax"
+         self.config.simulation.model = "guts_constant_exposure"

          self.solver = JaxSolver

@@ -500,7 +710,7 @@ class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
          self.config.model_parameters.remove("C_0")

          self.model_parameters["parameters"] = self.config.model_parameters.value_dict
-         self.config.simulation.solver_post_processing = "post_exposure"
+         self.config.simulation.solver_post_processing = "red_sd_post_processing"
          self.config.simulation.model = "guts_variable_exposure"


@@ -0,0 +1,31 @@
+ import os
+ import arviz as az
+ from guts_base.sim import GutsBase
+
+ def construct_sim_from_config(
+     scenario: str,
+     simulation_class: type,
+     output_path=None
+ ) -> GutsBase:
+     """Helper function to construct simulations for debugging"""
+     sim = simulation_class(f"scenarios/{scenario}/settings.cfg")
+
+     # this sets a different output directory
+     if output_path is not None:
+         p = output_path / sim.config.case_study.name / "results" / sim.config.case_study.scenario
+         sim.config.case_study.output_path = str(p)
+     else:
+         sim.config.case_study.scenario = "debug_test"
+     sim.setup()
+     return sim
+
+
+ def load_idata(sim: GutsBase, idata_file: str) -> GutsBase:
+     sim.set_inferer("numpyro")
+
+     if os.path.exists(idata_file):
+         sim.inferer.idata = az.from_netcdf(idata_file)
+     else:
+         sim.inferer.idata = None
+
+     return sim
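The added helper module (its file name is not shown in this extract) wires a scenario configuration and stored inference data together for debugging. A hedged usage sketch; the scenario name, simulation class, and idata path are placeholders:

    # Hypothetical usage of the debugging helpers defined above.
    # construct_sim_from_config and load_idata are the functions from this new
    # module; all string values are placeholders.
    from guts_base.sim import GutsBase

    sim = construct_sim_from_config(scenario="my_scenario", simulation_class=GutsBase)
    sim = load_idata(sim, idata_file="results/my_scenario/numpyro_posterior.nc")

    if sim.inferer.idata is not None:
        print(sim.point_estimate(estimate="map", to="dict"))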