guts-base 0.8.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic; see the package's registry page for more details.

guts_base/sim/base.py CHANGED
@@ -1,15 +1,18 @@
1
1
  import os
2
2
  import glob
3
+ from functools import partial
4
+ from copy import deepcopy
5
+ import importlib
3
6
  import warnings
4
7
  import numpy as np
5
8
  import xarray as xr
6
9
  from diffrax import Dopri5
7
- from typing import Literal, Optional, List, Dict
10
+ from typing import Literal, Optional, List, Dict, Mapping, Sequence, Tuple
8
11
  import tempfile
9
12
  import pandas as pd
10
13
 
11
14
  from pymob import SimulationBase
12
- from pymob.sim.config import DataVariable, Param, string_to_list
15
+ from pymob.sim.config import DataVariable, Param, string_to_list, NumericArray
13
16
 
14
17
  from pymob.solvers import JaxSolver
15
18
  from pymob.solvers.base import rect_interpolation
@@ -17,14 +20,15 @@ from expyDB.intervention_model import (
17
20
  Treatment, Timeseries, select, from_expydb
18
21
  )
19
22
 
23
+
24
+ from guts_base.sim.utils import GutsBaseError
20
25
  from guts_base import mod
21
26
  from guts_base.data import (
22
27
  to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
23
- create_database_and_import_data_main, design_exposure_scenario
28
+ create_database_and_import_data_main, design_exposure_scenario, ExposureDataDict
24
29
  )
25
30
  from guts_base.sim.report import GutsReport
26
31
 
27
-
28
32
  class GutsBase(SimulationBase):
29
33
  """
30
34
  Initializes GUTS models from a variety of data sources
@@ -33,48 +37,77 @@ class GutsBase(SimulationBase):
33
37
  1. check if necessary entries are made in the configuration, otherwise add defaults
34
38
  2. read data or take from input
35
39
  3. process data (add dimensions, or add indices)
40
+ 4. Prepare model input
36
41
  """
37
42
  solver = JaxSolver
38
43
  Report = GutsReport
39
- unit_time: Literal["day", "hour", "minute", "second"] = "day"
40
- results_interpolation: Optional[List[float|int]] = [np.nan, np.nan, 100]
41
- ecx_mode: Literal["mean", "draws"] = "mean"
44
+ results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
45
+ _skip_data_processing: bool = False
46
+
47
+ def initialize(self, input: Optional[Dict] = None):
48
+ """Initiaization goes through a couple of steps:
49
+
50
+ 1. Configuration: This makes case-study specific changes to the configuration
51
+ file or sets state variables that are relevant for the simulation
52
+ TODO: Ideally everything that is configurable ends up in the config so it
53
+ can be serialized
54
+
55
+ 2. Import data: This method consists of submethods that can be adapted or
56
+ overwritten in subclass methods.
57
+ - .read_data
58
+ - .save_observations
59
+ - .process_data
60
+ process_data itself utilizes the submethods _create_indices and
61
+ _indices_to_dimensions which are empty methods by default, but can be used
62
+ in subclasses if needed
63
+
64
+ 3. Initialize the simulation input (parameters, y0, x_in). This can
65
+
66
+ By splitting up the simulation init method, into these three steps, modifcations
67
+ of the initialize method allows for higher granularity in subclasses.
68
+ """
42
69
 
43
- def initialize(self, input: Dict = None):
70
+ # 1. Configuration
71
+ self.configure_case_study()
44
72
 
45
- if hasattr(self.config.simulation, "unit_time"):
46
- self.unit_time = self.config.simulation.unit_time # type: ignore
73
+ # 2. Import data
74
+ self.observations = self.read_data()
75
+ # FIXME: Saving observations here is not intuituve. If i export a simulation,
76
+ # I want to use the last used state, not some obscure intermediate state
77
+ # self.save_observations(filename="observations.nc", directory=self.output_path, force=True)
78
+ if not self._skip_data_processing:
79
+ self.process_data()
47
80
 
48
- if hasattr(self.config.simulation, "results_interpolation"):
49
- self.results_interpolation = string_to_list(self.config.simulation.results_interpolation)
50
- self.results_interpolation[0] = float(self.results_interpolation[0])
51
- self.results_interpolation[1] = float(self.results_interpolation[1])
52
- self.results_interpolation[2] = int(self.results_interpolation[2])
81
+ # 3. prepare y0 and x_in
82
+ self.prepare_simulation_input()
53
83
 
54
- if "observations" in input:
55
- self.observations = input["observations"]
56
- else:
57
- self.observations = self.read_data()
58
- self.process_data()
84
+ def configure_case_study(self):
85
+ """Modify configuration file or set state variables
86
+ TODO: This should only modify the configuration file, so that changes
87
+ are transparent.
88
+ """
89
+ if self._model_class is not None:
90
+ self.model = self._model_class._rhs_jax
91
+ self.solver_post_processing = self._model_class._solver_post_processing
59
92
 
60
- # define tolerance based on the sovler tolerance
61
- self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
93
+ self.ecx_mode: Literal["mean", "draws"] = "mean"
62
94
 
63
- self._reindex_time_dim()
95
+ self.unit_time: Literal["day", "hour", "minute", "second"] = "day"
96
+ if hasattr(self.config.simulation, "unit_time"):
97
+ self.unit_time = self.config.simulation.unit_time # type: ignore
64
98
 
65
- if "survival" in self.observations:
66
- if "subject_count" not in self.observations.coords:
67
- self.observations = self.observations.assign_coords(
68
- subject_count=("id", self.observations["survival"].isel(time=0).values, )
69
- )
70
- self.observations = self._data.prepare_survival_data_for_conditional_binomial(
71
- observations=self.observations
72
- )
99
+ if hasattr(self.config.simulation, "skip_data_processing"):
100
+ self._skip_data_processing = bool(self.config.simulation.skip_data_processing) # type: ignore
73
101
 
74
- if "exposure" in self.observations:
75
- self.config.data_structure.exposure.observed=False
102
+ if hasattr(self.config.simulation, "results_interpolation"):
103
+ results_interpolation_string = string_to_list(self.config.simulation.results_interpolation)
104
+ self.results_interpolation = (
105
+ float(results_interpolation_string[0]),
106
+ float(results_interpolation_string[1]),
107
+ int(results_interpolation_string[2])
108
+ )
76
109
 
77
- # prepare y0 and x_in
110
+ def prepare_simulation_input(self):
78
111
  x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
79
112
  y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])
80
113
 
@@ -187,6 +220,23 @@ class GutsBase(SimulationBase):
187
220
  self._create_indices()
188
221
  self._indices_to_dimensions()
189
222
 
223
+ # define tolerance based on the sovler tolerance
224
+ self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
225
+
226
+ self._reindex_time_dim()
227
+
228
+ if "survival" in self.observations:
229
+ if "subject_count" not in self.observations.coords:
230
+ self.observations = self.observations.assign_coords(
231
+ subject_count=("id", self.observations["survival"].isel(time=0).values, )
232
+ )
233
+ self.observations = self._data.prepare_survival_data_for_conditional_binomial(
234
+ observations=self.observations
235
+ )
236
+
237
+ if "exposure" in self.observations:
238
+ self.config.data_structure.exposure.observed=False
239
+
190
240
  def _create_indices(self):
191
241
  """Use if indices should be added to sim.indices and sim.observations"""
192
242
  pass
@@ -223,30 +273,38 @@ class GutsBase(SimulationBase):
223
273
  "is calculated without a dense time resolution, the estimates can be biased!"
224
274
  ))
225
275
 
226
-
227
-
228
276
  def recompute_posterior(self):
229
277
  """This function interpolates the posterior with a given resolution
230
278
  posterior_predictions calculate proper survival predictions for the
231
279
  posterior.
280
+
281
+ It also makes sure that the new interpolation does not include fewer values
282
+ than the original dataset
232
283
  """
233
284
 
234
- if np.isnan(self.results_interpolation[0]):
235
- self.results_interpolation[0] = float(self.observations["time"].min())
236
-
237
- if np.isnan(self.results_interpolation[1]):
238
- self.results_interpolation[1] = float(self.observations["time"].max())
285
+ ri = self.results_interpolation
239
286
 
240
287
  # generate high resolution posterior predictions
241
288
  if self.results_interpolation is not None:
242
289
  time_interpolate = np.linspace(
243
- start=self.results_interpolation[0],
244
- stop=self.results_interpolation[1],
290
+ start=float(self.observations["time"].min()) if np.isnan(ri[0]) else ri[0],
291
+ stop=float(self.observations["time"].max()) if np.isnan(ri[0]) else ri[1],
245
292
  num=self.results_interpolation[2]
246
293
  )
247
- self.observations = self.observations.reindex(
248
- time=time_interpolate
294
+
295
+ # combine original coordinates and interpolation. This
296
+ # a) helps error checking during posterior predictions.
297
+ # b) makes sure that the original time vector is retained, which may be
298
+ # relevant for the simulation success (e.g. IT model)
299
+ obs = self.observations.reindex(
300
+ time=np.unique(np.concatenate(
301
+ [time_interpolate, self.observations["time"]]
302
+ )),
249
303
  )
304
+
305
+ obs["survivors_before_t"] = obs.survivors_before_t.ffill(dim="time").astype(int)
306
+ obs["survivors_at_start"] = obs.survivors_at_start.ffill(dim="time").astype(int)
307
+ self.observations = obs
250
308
 
251
309
  self.dispatch_constructor()
252
310
  _ = self._prob.posterior_predictions(self, self.inferer.idata) # type: ignore
@@ -254,13 +312,13 @@ class GutsBase(SimulationBase):
254
312
  self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
255
313
 
256
314
 
257
- def prior_predictive_checks(self):
258
- super().prior_predictive_checks()
315
+ def prior_predictive_checks(self, **plot_kwargs):
316
+ super().prior_predictive_checks(**plot_kwargs)
259
317
 
260
318
  self._plot.plot_prior_predictions(self, data_vars=["survival"])
261
319
 
262
- def posterior_predictive_checks(self):
263
- super().posterior_predictive_checks()
320
+ def posterior_predictive_checks(self, **plot_kwargs):
321
+ super().posterior_predictive_checks(**plot_kwargs)
264
322
 
265
323
  self.recompute_posterior()
266
324
  # TODO: Include posterior_predictive group once the survival predictions are correctly working
@@ -270,37 +328,93 @@ class GutsBase(SimulationBase):
270
328
  def plot(self, results):
271
329
  self._plot.plot_survival(self, results)
272
330
 
273
- def copy(self):
274
- with warnings.catch_warnings(action="ignore"):
275
- sim_copy = type(self)(self.config)
276
- sim_copy.observations = self.observations
277
- sim_copy.model_parameters = self.model_parameters
278
- if self.inferer is not None:
279
- sim_copy.inferer = type(self.inferer)(self)
280
- sim_copy.inferer.idata = self.inferer.idata
281
- sim_copy.model = self.model
282
- sim_copy.solver_post_processing = self.solver_post_processing
283
- sim_copy.load_modules()
284
-
285
- return sim_copy
286
-
287
-
288
- @property
289
331
  def predefined_scenarios(self):
290
- # this produces a exposure x_in dataset with only the dimensions ID and TIME
291
- oral_acute_1d = design_exposure_scenario(
292
- exposures={
293
- "oral":dict(start=0, end=1.0, concentration=1.0),
294
- },
295
- t_max=10.01,
296
- dt=1/24,
297
- exposure_dimension="exposure_path"
332
+ """
333
+ TODO: Fix timescale to observations
334
+ TODO: Incorporate extra exposure patterns (constant, pulse_1day, pulse_2day)
335
+ """
336
+ # get the maximum possible time to provide exposure scenarios that are definitely
337
+ # long enough
338
+ time_max = max(
339
+ self.observations[self.config.simulation.x_dimension].max(),
340
+ *self.Report.ecx_estimates_times
298
341
  )
299
342
 
300
- return dict(
301
- oral_acute_1d=oral_acute_1d
343
+ # this produces a exposure x_in dataset with only the dimensions ID and TIME
344
+ standard_dimensions = (
345
+ self.config.simulation.batch_dimension,
346
+ self.config.simulation.x_dimension,
302
347
  )
303
348
 
349
+ # get dimensions different from standard dimensions
350
+ exposure_dimension = [
351
+ d for d in self.observations.exposure.dims if d not in
352
+ standard_dimensions
353
+ ]
354
+
355
+ # raise an error if the number of extra dimensions is larger than 1
356
+ if len(exposure_dimension) > 1:
357
+ raise ValueError(
358
+ f"{type(self).__name__} can currently handle one additional dimension for "+
359
+ f"the exposure beside {standard_dimensions}. You provided an exposure "+
360
+ f"array with the dimensions: {self.observations.exposure.dims}"
361
+ )
362
+ else:
363
+ exposure_dimension = exposure_dimension[0]
364
+
365
+ # iterate over the coordinates of the exposure dimensions to
366
+ exposure_coordinates = self.observations.exposure[exposure_dimension].values
367
+
368
+ scenarios = {}
369
+ for coord in exposure_coordinates:
370
+ concentrations = np.where(coord == exposure_coordinates, 1.0, 0.0)
371
+
372
+ exposure_dict = {
373
+ coord: ExposureDataDict(start=0, end=1, concentration=conc)
374
+ for coord, conc in zip(exposure_coordinates, concentrations)
375
+ }
376
+
377
+ scenario = design_exposure_scenario(
378
+ exposures=exposure_dict,
379
+ t_max=time_max,
380
+ dt=1/24,
381
+ exposure_dimension=exposure_dimension
382
+ )
383
+
384
+ scenarios.update({
385
+ f"1day_exposure_{coord}": scenario
386
+ })
387
+
388
+ return scenarios
389
+
390
@staticmethod
def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
    """Convert a mapping of exposure tables into a Dataset holding one ``exposure`` array.

    Each DataFrame gets its index labelled ``"time"`` and is converted with
    ``DataFrame.to_xarray``; its columns are then stacked along an ``"id"``
    dimension. The per-key arrays are stacked along ``dim`` (the mapping keys
    become the coordinates of ``dim``) and transposed so ``id`` and ``time``
    come first.

    Parameters
    ----------
    exposure_data : Dict[str, pd.DataFrame]
        One table per exposure key. Presumably index = time and columns = ids
        (TODO confirm against the callers). The input DataFrames are not
        modified.
    dim : str
        Name of the dimension that holds the exposure keys.

    Returns
    -------
    xr.Dataset
        Dataset with a single ``exposure`` variable with dims ``("id", "time", dim)``.

    Notes
    -----
    TODO: Currently no rect interpolation
    """
    arrays = {}
    for key, df in exposure_data.items():
        # Relabel the index on a copy. Assigning ``df.index.name`` directly
        # would silently rename the index of the caller's DataFrame. The
        # "time" label is necessary to make all dimensions work out.
        arrays[key] = df.rename_axis("time").to_xarray().to_dataarray(dim="id", name=key)

    exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
    exposure_array = exposure_array.transpose("id", "time", ...)
    return xr.Dataset({"exposure": exposure_array})
406
+
407
@staticmethod
def _survival_data_to_xarray(survival_data: pd.DataFrame):
    """Convert a survival table into a Dataset holding one ``survival`` array.

    The index is labelled ``"time"`` and the columns are stacked along an
    ``"id"`` dimension, so the resulting array has dims ``("id", "time")``.
    The input DataFrame is not modified.

    Parameters
    ----------
    survival_data : pd.DataFrame
        Presumably index = time and columns = ids (TODO confirm against the callers).

    Returns
    -------
    xr.Dataset
        Dataset with a single ``survival`` variable.
    """
    # TODO: survival name is currently not kept because the raw data is not transferred from the survival
    # Relabel on a copy so the caller's DataFrame keeps its original index name.
    survival = survival_data.rename_axis("time")

    survival_array = survival.to_xarray().to_dataarray(dim="id", name="survival")
    survival_array = survival_array.transpose("id", "time", ...)
    return xr.Dataset({"survival": survival_array})
416
+
417
+
304
418
  def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
305
419
  """This method will take an existing coordinate of a dataset that has the same
306
420
  coordinate has the batch dimension. It will then re-express the coordinate as a
@@ -367,6 +481,164 @@ class GutsBase(SimulationBase):
367
481
  def initialize_from_script(self):
368
482
  pass
369
483
 
484
@property
def _model_class(self):
    """Resolve ``config.simulation.model_class`` to the object it names.

    The setting is a dotted path such as ``"package.module.Attribute"``.
    Everything before the last dot is imported as a module, and the trailing
    name is looked up on that module. Returns ``None`` when the setting is
    absent from the simulation config.
    """
    if not hasattr(self.config.simulation, "model_class"):
        return None

    module_path, attr_name = self.config.simulation.model_class.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)
492
+
493
+ ### API methods ###
494
+
495
+
496
def point_estimate(
    self,
    estimate: Literal["mean", "map"] = "map",
    to: Literal["xarray", "dict"] = "xarray"
):
    """Return a point estimate of the posterior.

    If you want more control over the posterior, use the attribute
    ``sim.inferer.idata.posterior`` and summarize it or select from it with the
    arviz (https://python.arviz.org/en/stable/index.html) and
    xarray (https://docs.xarray.dev/en/stable/index.html) packages.

    Parameters
    ----------
    estimate : Literal["map", "mean"]
        Point estimate to return.
        - map: the single posterior draw with the highest total log-likelihood.
          The total is summed over all non-sample dimensions and all variables
          of ``idata.log_likelihood``. Because it is one joint draw, it keeps
          the correlation structure of the posterior. Note that only the
          log-likelihood is scored; the prior is not included.
        - mean: the average of all marginal parameter distributions.

    to : Literal["xarray", "dict"]
        Specifies the representation to transform the summarized data to.
        ``dict`` maps parameter names to numpy arrays and can be used to
        insert parameters in the ``.evaluate()`` method. ``xarray`` is the
        standard view. Defaults to xarray.

    Returns
    -------
    xr.Dataset or Dict[str, np.ndarray]
        The selected/summarized posterior in the requested representation.

    Raises
    ------
    GutsBaseError
        If ``estimate`` or ``to`` is not one of the supported options.

    Example
    -------

    >>> sim.point_estimate(to='dict')
    """
    idata = self.inferer.idata
    sample_dims = ("chain", "draw")

    if estimate == "mean":
        best_estimate = idata.posterior.mean(sample_dims)

    elif estimate == "map":
        log_likelihood = idata.log_likelihood
        # Reduce every non-sample dimension (e.g. id, time) so that each
        # draw receives exactly one total log-likelihood score.
        observation_dims = [d for d in log_likelihood.dims if d not in sample_dims]
        loglik = log_likelihood.sum(observation_dims).to_array().sum("variable")

        # argmax returns *positional* indices, so they must be applied with
        # isel. sel would only coincide when the chain/draw coordinates
        # happen to equal 0..n-1.
        sample_max_loglik = loglik.argmax(dim=sample_dims)
        best_estimate = idata.posterior.isel(sample_max_loglik)
    else:
        raise GutsBaseError(
            f"Estimate '{estimate}' not implemented. Choose one of ['mean', 'map']"
        )

    if to == "xarray":
        return best_estimate

    elif to == "dict":
        return {k: v.values for k, v in best_estimate.items()}

    else:
        raise GutsBaseError(
            f"{type(self).__name__}.point_estimate() supports only return types "
            f"to=['xarray', 'dict']. You used {to=}"
        )
552
+
553
+
554
def evaluate(
    self,
    parameters: Optional[Mapping[str, float|NumericArray|Sequence[float]]] = None,
    y0: Optional[Mapping[str, float|NumericArray|Sequence[float]]] = None,
    x_in: Optional[Mapping[str, float|NumericArray|Sequence[float]]] = None,
):
    """Evaluate the model with the given parameters, x_in, and y0.

    The model is evaluated along the coordinates of the observations. The
    mappings passed to the function arguments only overwrite the existing
    default values, which makes the usage very simple.

    Note that the first run of ``.evaluate()`` after calling
    ``.dispatch_constructor()`` takes a little longer, because the model and
    solver are jit-compiled to JAX for highly efficient computations.

    Parameters
    ----------

    parameters : Mapping[str, float|NumericArray|Sequence[float]], optional
        Model parameters that should be changed for dispatch (forwarded as
        ``theta``). Unspecified model parameters will assume the default
        values specified under config.model_parameters.NAME.value.

    y0 : Mapping[str, float|NumericArray|Sequence[float]], optional
        Initial values that should be changed for dispatch.

    x_in : Mapping[str, float|NumericArray|Sequence[float]], optional
        Model input values that should be changed for dispatch.

    Returns
    -------
    The ``results`` attribute of the evaluator after it has been run.

    Example
    -------

    >>> sim.dispatch_constructor() # necessary if the sim object has been modified
    >>> # evaluate setting the background mortaltiy to zero
    >>> sim.evaluate(parameters={'hb': 0.0})

    """
    # ``None`` defaults instead of shared ``{}`` literals. A mutable default
    # is created once and reused by every call, so any in-place change made
    # downstream would otherwise leak into later calls.
    evaluator = self.dispatch(
        theta={} if parameters is None else parameters,
        x_in={} if x_in is None else x_in,
        y0={} if y0 is None else y0,
    )
    evaluator()
    return evaluator.results
594
+
595
def load_exposure_scenario(
    self,
    data: str|Dict[str,pd.DataFrame],
    sheet_name_prefix: str = "",
    rect_interpolate=False

):
    """Replace the exposure of the simulation with a new exposure scenario.

    The steps are:
    - A deep copy of the current observations is stored in ``self._obs_backup``.
    - The new exposure is combined with the existing ``survival``
      observations, restricted to the ids of the new exposure and to time
      points not later than its last time point.
    - The result is written to ``model_parameters["x_in"]`` (forward-filled
      along ``time``), and ``y0`` is re-parsed.
    - The simulation is re-dispatched.

    Parameters
    ----------
    data : str | Dict[str, pd.DataFrame]
        Either a path to an Excel file, read with ``read_excel_file`` using
        ``convert_time_to=self.unit_time``, or a mapping of exposure key to
        DataFrame as accepted by ``_exposure_data_to_xarray``.
    sheet_name_prefix : str
        Forwarded to ``read_excel_file`` when ``data`` is a path.
    rect_interpolate : bool
        Currently unused (NOTE(review): not wired into the conversion yet).

    Raises
    ------
    GutsBaseError
        If the configured exposure dimensions contain no dimension besides
        the x and batch dimensions.
    """
    if isinstance(data, str):
        # NOTE(review): ``read_excel_file`` is not imported in the visible part
        # of this module. Confirm it is defined or imported elsewhere,
        # otherwise this branch raises NameError.
        exposure_data, _time_unit = read_excel_file(
            path=data,
            sheet_name_prefix=sheet_name_prefix,
            convert_time_to=self.unit_time
        )
    else:
        exposure_data = data

    self._obs_backup = self.observations.copy(deep=True)

    # The exposure keys are stacked along the first configured exposure
    # dimension that is neither the x dimension nor the batch dimension.
    standard_dims = (
        self.config.simulation.x_dimension,
        self.config.simulation.batch_dimension,
    )
    exposure_dims = [
        d for d in self.config.data_structure.exposure.dimensions
        if d not in standard_dims
    ]
    if not exposure_dims:
        raise GutsBaseError(
            "Cannot load an exposure scenario: the configured exposure dimensions "
            f"{list(self.config.data_structure.exposure.dimensions)} contain no "
            f"dimension besides {standard_dims}."
        )

    exposure = self._exposure_data_to_xarray(
        exposure_data=exposure_data,
        dim=exposure_dims[0]
    )

    # combine with observations and keep only the ids of the new scenario
    new_obs = xr.combine_by_coords([
        exposure,
        self.observations.survival
    ]).sel(id=exposure.id)

    # Drop time points after the end of the new exposure. A single vectorized
    # mask replaces comparing every time point in a Python loop.
    within_exposure = (new_obs.time <= exposure.time.max()).values
    self.observations = new_obs.isel(time=within_exposure)
    self.config.simulation.x_in = ["exposure=exposure"]
    self.model_parameters["x_in"] = self.parse_input("x_in", exposure).ffill("time") # type: ignore
    self.model_parameters["y0"] = self.parse_input("y0", drop_dims=["time"])

    self.dispatch_constructor()
637
+
638
def export(self, directory: Optional[str] = None):
    """Export the simulation through the parent class's ``export``.

    Before delegating, ``config.simulation.skip_data_processing`` is reset to
    ``False``. Presumably this is so that a re-imported export processes its
    data again (TODO confirm).
    """
    simulation_settings = self.config.simulation
    simulation_settings.skip_data_processing = False

    super().export(directory=directory)
641
+
370
642
  class GutsSimulationConstantExposure(GutsBase):
371
643
  t_max = 10
372
644
  def initialize_from_script(self):
@@ -387,7 +659,7 @@ class GutsSimulationConstantExposure(GutsBase):
387
659
  self.config.model_parameters.z = Param(value=0.2, free=True)
388
660
 
389
661
  self.model_parameters["parameters"] = self.config.model_parameters.value_dict
390
- self.config.simulation.model = "guts_jax"
662
+ self.config.simulation.model = "guts_constant_exposure"
391
663
 
392
664
  self.coordinates["time"] = np.linspace(0,self.t_max)
393
665
 
@@ -397,7 +669,7 @@ class GutsSimulationConstantExposure(GutsBase):
397
669
  # =======================
398
670
 
399
671
  self.coordinates["time"] = np.array([0,self.t_max])
400
- self.config.simulation.model = "guts_jax"
672
+ self.config.simulation.model = "guts_constant_exposure"
401
673
 
402
674
  self.solver = JaxSolver
403
675
 
@@ -438,7 +710,7 @@ class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
438
710
  self.config.model_parameters.remove("C_0")
439
711
 
440
712
  self.model_parameters["parameters"] = self.config.model_parameters.value_dict
441
- self.config.simulation.solver_post_processing = "post_exposure"
713
+ self.config.simulation.solver_post_processing = "red_sd_post_processing"
442
714
  self.config.simulation.model = "guts_variable_exposure"
443
715
 
444
716
 
@@ -0,0 +1,31 @@
1
+ import os
2
+ import arviz as az
3
+ from guts_base.sim import GutsBase
4
+
5
def construct_sim_from_config(
    scenario: str,
    simulation_class: type,
    output_path=None
) -> "GutsBase":
    """Helper function to construct simulations for debugging.

    Parameters
    ----------
    scenario : str
        Scenario name. The settings are read from
        ``scenarios/{scenario}/settings.cfg``, relative to the working directory.
    simulation_class : type
        Simulation class that is instantiated with that settings path.
    output_path : str | os.PathLike | None
        If given, the output path of the case study is set to
        ``output_path/<config.case_study.name>/results/<config.case_study.scenario>``.
        Both plain strings and ``pathlib.Path`` objects are accepted.
        If omitted, ``config.case_study.scenario`` is set to ``"debug_test"``
        instead.

    Returns
    -------
    The simulation instance after ``setup()`` has been called on it.
    """
    from pathlib import Path

    sim = simulation_class(f"scenarios/{scenario}/settings.cfg")

    # this sets a different output directory
    if output_path is not None:
        # Wrap in Path so plain string paths work too (``str / str`` raises TypeError).
        p = Path(output_path) / sim.config.case_study.name / "results" / sim.config.case_study.scenario
        sim.config.case_study.output_path = str(p)
    else:
        sim.config.case_study.scenario = "debug_test"
    sim.setup()
    return sim
21
+
22
+
23
def load_idata(sim: GutsBase, idata_file: str) -> GutsBase:
    """Attach inference data stored in a NetCDF file to ``sim``.

    A ``"numpyro"`` inferer is set on the simulation first. If ``idata_file``
    exists, it is read with ``arviz.from_netcdf``; otherwise
    ``sim.inferer.idata`` is set to ``None``. The same ``sim`` object is
    returned.
    """
    sim.set_inferer("numpyro")

    idata = az.from_netcdf(idata_file) if os.path.exists(idata_file) else None
    sim.inferer.idata = idata

    return sim