guts-base 1.0.0__tar.gz → 1.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

Files changed (37) hide show
  1. {guts_base-1.0.0 → guts_base-1.0.2}/PKG-INFO +2 -1
  2. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/__init__.py +1 -1
  3. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/plot.py +34 -2
  4. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/sim/__init__.py +1 -1
  5. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/sim/base.py +228 -7
  6. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/sim/ecx.py +56 -9
  7. guts_base-1.0.2/guts_base/sim/report.py +178 -0
  8. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base.egg-info/PKG-INFO +2 -1
  9. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base.egg-info/requires.txt +1 -0
  10. {guts_base-1.0.0 → guts_base-1.0.2}/pyproject.toml +4 -3
  11. guts_base-1.0.0/guts_base/sim/report.py +0 -72
  12. {guts_base-1.0.0 → guts_base-1.0.2}/LICENSE +0 -0
  13. {guts_base-1.0.0 → guts_base-1.0.2}/README.md +0 -0
  14. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/__init__.py +0 -0
  15. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/expydb.py +0 -0
  16. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/generator.py +0 -0
  17. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/openguts.py +0 -0
  18. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/preprocessing.py +0 -0
  19. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/survival.py +0 -0
  20. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/time_of_death.py +0 -0
  21. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/data/utils.py +0 -0
  22. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/mod.py +0 -0
  23. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/prob.py +0 -0
  24. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/sim/constructors.py +0 -0
  25. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/sim/mempy.py +0 -0
  26. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base/sim/utils.py +0 -0
  27. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base.egg-info/SOURCES.txt +0 -0
  28. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base.egg-info/dependency_links.txt +0 -0
  29. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base.egg-info/entry_points.txt +0 -0
  30. {guts_base-1.0.0 → guts_base-1.0.2}/guts_base.egg-info/top_level.txt +0 -0
  31. {guts_base-1.0.0 → guts_base-1.0.2}/setup.cfg +0 -0
  32. {guts_base-1.0.0 → guts_base-1.0.2}/tests/test_data_import.py +0 -0
  33. {guts_base-1.0.0 → guts_base-1.0.2}/tests/test_ecx.py +0 -0
  34. {guts_base-1.0.0 → guts_base-1.0.2}/tests/test_from_pymob.py +0 -0
  35. {guts_base-1.0.0 → guts_base-1.0.2}/tests/test_scripted_simulations.py +0 -0
  36. {guts_base-1.0.0 → guts_base-1.0.2}/tests/test_simulations.py +0 -0
  37. {guts_base-1.0.0 → guts_base-1.0.2}/tests/test_symbolic_solve.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 1.0.0
3
+ Version: 1.0.2
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -693,6 +693,7 @@ Requires-Dist: openpyxl>=3.1.3
693
693
  Requires-Dist: Bottleneck>=1.5.0
694
694
  Requires-Dist: expydb>=0.6.0
695
695
  Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
696
+ Requires-Dist: pint
696
697
  Provides-Extra: dev
697
698
  Requires-Dist: pytest>=7.3; extra == "dev"
698
699
  Requires-Dist: bumpver; extra == "dev"
@@ -4,7 +4,7 @@ from . import data
4
4
  from . import prob
5
5
  from . import plot
6
6
 
7
- __version__ = "1.0.0"
7
+ __version__ = "1.0.2"
8
8
 
9
9
  from .sim import (
10
10
  GutsBase,
@@ -1,3 +1,4 @@
1
+ from cycler import cycler
1
2
  import numpy as np
2
3
  import arviz as az
3
4
  from matplotlib import pyplot as plt
@@ -119,8 +120,7 @@ def plot_survival_multipanel(sim, results, ncols=6, title=lambda _id: _id):
119
120
  for _id, ax in zip(sim.observations.id.values, axes):
120
121
  ax.set_ylim(-0.05,1.05)
121
122
 
122
- # TODO: use time unit from observations (?)
123
- ax.set_xlabel("Time")
123
+ ax.set_xlabel(f"Time [{sim.unit_time}]")
124
124
  ax.set_ylabel("Survival")
125
125
  ax.plot(mean.time, mean.sel(id=_id).survival.T, **plot_kwargs)
126
126
  ax.fill_between(hdi.time, *hdi.sel(id=_id).survival.T, alpha=.5, **plot_kwargs)
@@ -133,6 +133,38 @@ def plot_survival_multipanel(sim, results, ncols=6, title=lambda _id: _id):
133
133
 
134
134
  return out
135
135
 
136
def plot_exposure_multipanel(sim, results, ncols=6, title=lambda _id: _id):
    """Plot the exposure profile of every ID in its own panel and save the figure.

    Parameters
    ----------
    sim
        Simulation object. Read for ``observations.id``, ``unit_time``,
        ``coordinates``, ``_exposure_dimension`` and ``output_path``.
    results
        xarray-like object with an ``id`` dimension, a ``time`` coordinate and
        an ``exposure`` variable that can be selected by ``id`` and by the
        exposure dimension of ``sim``.
    ncols : int
        Number of panel columns. The number of rows follows from the number
        of IDs in ``results``.
    title : callable
        Receives an ID and returns the title of the corresponding panel.

    Returns
    -------
    str
        Path of the written file ``<sim.output_path>/exposure_multipanel.png``.
    """
    n_panels = results.sizes["id"]
    nrows = int(np.ceil(n_panels / ncols))

    # squeeze=False guarantees a 2-D array of axes. Without it a 1x1 grid is
    # returned as a bare Axes object, which has no ``flatten`` method.
    fig, axes = plt.subplots(
        nrows, ncols,
        sharex=True,
        squeeze=False,
        figsize=(ncols*2+2, nrows*1.5+2),
    )
    axes = axes.flatten()

    plot_kwargs = {"color": "black"}
    # all lines are black; exposures within a panel differ by line style
    custom_cycler = cycler(ls=["-", "--", ":", "-."])

    exposure_dim = sim._exposure_dimension

    for _id, ax in zip(sim.observations.id.values, axes):
        ax.set_prop_cycle(custom_cycler)

        ax.set_xlabel(f"Time [{sim.unit_time}]")
        ax.set_ylabel("Exposure")
        for expo in sim.coordinates[exposure_dim]:
            ax.plot(
                results.time,
                results.sel({"id": _id, exposure_dim: expo}).exposure,
                **plot_kwargs, label=f"Exposure: {expo}"
            )
        ax.set_title(title(_id))
        ax.legend(fontsize=7)

    out = f"{sim.output_path}/exposure_multipanel.png"
    fig.tight_layout()
    fig.savefig(out)

    return out
136
168
 
137
169
  def multipanel_title(sim, _id):
138
170
  oid = sim.observations.sel(id=_id)
@@ -10,7 +10,7 @@ from .base import (
10
10
  )
11
11
 
12
12
  from .ecx import ECxEstimator, LPxEstimator
13
- from .report import GutsReport
13
+ from .report import GutsReport, ParameterConverter
14
14
 
15
15
  from .mempy import PymobSimulator
16
16
  from .utils import (
@@ -3,16 +3,19 @@ import glob
3
3
  from functools import partial
4
4
  from copy import deepcopy
5
5
  import importlib
6
+ import tempfile
6
7
  import warnings
7
8
  import numpy as np
8
9
  import xarray as xr
9
10
  from diffrax import Dopri5
10
11
  from typing import Literal, Optional, List, Dict, Mapping, Sequence, Tuple
11
- import tempfile
12
12
  import pandas as pd
13
+ import pint
13
14
 
14
15
  from pymob import SimulationBase
15
- from pymob.sim.config import DataVariable, Param, string_to_list, NumericArray
16
+ from pymob.sim.config import (
17
+ DataVariable, Param, string_to_list, string_to_dict, NumericArray
18
+ )
16
19
 
17
20
  from pymob.solvers import JaxSolver
18
21
  from pymob.solvers.base import rect_interpolation
@@ -29,6 +32,8 @@ from guts_base.data import (
29
32
  )
30
33
  from guts_base.sim.report import GutsReport
31
34
 
35
+ ureg = pint.UnitRegistry()
36
+
32
37
  class GutsBase(SimulationBase):
33
38
  """
34
39
  Initializes GUTS models from a variety of data sources
@@ -43,6 +48,7 @@ class GutsBase(SimulationBase):
43
48
  Report = GutsReport
44
49
  results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
45
50
  _skip_data_processing: bool = False
51
+ ecx_mode: Literal["mean", "draws"] = "mean"
46
52
 
47
53
  def initialize(self, input: Optional[Dict] = None):
48
54
  """Initialization goes through a couple of steps:
@@ -90,8 +96,6 @@ class GutsBase(SimulationBase):
90
96
  self.model = self._model_class._rhs_jax
91
97
  self.solver_post_processing = self._model_class._solver_post_processing
92
98
 
93
- self.ecx_mode: Literal["mean", "draws"] = "mean"
94
-
95
99
  self.unit_time: Literal["day", "hour", "minute", "second"] = "day"
96
100
  if hasattr(self.config.simulation, "unit_time"):
97
101
  self.unit_time = self.config.simulation.unit_time # type: ignore
@@ -107,6 +111,8 @@ class GutsBase(SimulationBase):
107
111
  int(results_interpolation_string[2])
108
112
  )
109
113
 
114
+ self._determine_background_mortality_parameter()
115
+
110
116
  def prepare_simulation_input(self):
111
117
  x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
112
118
  y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])
@@ -132,7 +138,18 @@ class GutsBase(SimulationBase):
132
138
  Timeseries.name == {exposure_path}
133
139
  )
134
140
 
135
- def read_data(self):
141
+ def read_data(self) -> xr.Dataset:
142
+ """Reads data and returns an xarray.Dataset.
143
+
144
+ GutsBase supports reading data from
145
+ - netcdf (.nc) files
146
+ - expyDB (SQLite databases)
147
+ - excel (directories of excel files)
148
+
149
+ expyDB and excel operate by converting data to xarrays while netcdf directly
150
+ loads xarray Datasets. For highest control over your data, you should always use
151
+ .nc files, because they are imported as-is.
152
+ """
136
153
  # TODO: Update to new INTERVENTION MODEL
137
154
  dataset = str(self.config.case_study.observations)
138
155
 
@@ -215,11 +232,9 @@ class GutsBase(SimulationBase):
215
232
  Currently these methods, change datasets, indices, etc. in-place.
216
233
  This is convenient, but more difficult to re-arrange with other methods
217
234
  TODO: Make these methods static if possible
218
-
219
235
  """
220
236
  self._create_indices()
221
237
  self._indices_to_dimensions()
222
-
223
238
  # define tolerance based on the solver tolerance
224
239
  self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
225
240
 
@@ -237,6 +252,89 @@ class GutsBase(SimulationBase):
237
252
  if "exposure" in self.observations:
238
253
  self.config.data_structure.exposure.observed=False
239
254
 
255
def _convert_exposure_units(self):
    """Rescale the exposure observations to the units requested in the config.

    Does nothing when ``config.simulation`` has no ``unit_exposure`` entry.
    Otherwise the ``unit`` coordinate of the observations is replaced by the
    target units, the conversion factors are stored as the coordinate
    ``unit_conversion_factors`` and the data variable named by
    ``config.simulation.substance`` is multiplied by those factors.

    TODO: Here I need to decide what to do. Working with rescaled units is
    dangerous because fitting might be complicated with weird quantities.
    It would be better to rescale output parameters.
    """
    if not hasattr(self.config.simulation, "unit_exposure"):
        return

    # the unit labels without the 'unit' coordinate attached to themselves
    source_units = self.observations.unit.reset_coords("unit", drop=True)
    units, unit_conversion_factors = self._convert_units(
        source_units,
        target_units=self.config.simulation.unit_exposure
    )

    new_coords = {
        "unit": units,
        "unit_conversion_factors": unit_conversion_factors,
    }
    self.observations = self.observations.assign_coords(new_coords)

    substance = self.config.simulation.substance
    scaled = self.observations[substance] * unit_conversion_factors
    self.observations[substance] = scaled
276
+
277
+ @staticmethod
278
+ def _unique_unsorted(values):
279
+ _, index = np.unique(values, return_index=True)
280
+ return tuple(np.array(values)[sorted(index)])
281
+
282
@staticmethod
def _convert_units(
    units: xr.DataArray,
    target_units: Dict[str,str]
) -> Tuple[xr.DataArray, xr.DataArray]:
    """Convert the units associated with the exposure dimension.

    TODO: Converting before inference could be a problem for the calibration,
    because it is usually good if the values are both not too small and not
    too large.

    Parameters
    ----------
    units : xr.DataArray
        One-dimensional array holding one unit string per coordinate.
    target_units
        Config entry that is parsed with ``string_to_dict``. For every
        coordinate of ``units`` the parsed mapping must provide a string of
        the form ``"<transform>-><target>"``. ``<transform>`` may contain the
        placeholder ``{x}``, which is replaced by the unit found in ``units``.
        NOTE(review): the annotation says Dict, but the value is handed to
        ``string_to_dict`` -- confirm the expected input type.

    Returns
    -------
    Tuple[xr.DataArray, xr.DataArray]
        The target unit strings and the conversion factors (magnitudes of the
        converted quantities), both indexed by the dimension of ``units``.

    Raises
    ------
    GutsBaseError
        If ``units`` is not one-dimensional or if a converted unit does not
        equal its target unit.
    """

    if len(units.dims) != 1:
        raise GutsBaseError(
            "GutsBase_convert_exposure_units only supports 1 dimensional exposure units"
        )

    _dim = units.dims[0]
    _coordinates = units.coords[_dim]

    # the mapping from the config is the same for every coordinate,
    # so it is parsed once instead of once per loop iteration
    unit_mapping = string_to_dict(target_units)

    converted_units = {}
    _target_units = {}

    for coord in _coordinates.values:
        unit = str(units.sel({_dim: coord}).values)

        # split transformation expression from target expression
        transform, target = unit_mapping[coord].split("->")
        # insert unit from observations coordinates
        transform = transform.format(x=unit)

        # parse and convert units
        converted_units[coord] = ureg.parse_expression(transform).to(target)
        _target_units[coord] = target

    # assert whether the converted units are the same as the target units
    # so the target units can be used, because the converted units may reduce
    # to dimensionless quantities.
    units_match = all(
        cu.units == ureg.parse_expression(tu)
        for cu, tu in zip(converted_units.values(), _target_units.values())
    )
    if not units_match:
        raise GutsBaseError(
            f"Mismatch between target units {_target_units} and converted units " +
            f"{converted_units}."
        )

    _conversion_factors = {k: cu.magnitude for k, cu in converted_units.items()}
    new_unit_coords = xr.Dataset(_target_units).to_array(dim=_dim)
    conversion_factor_coords = xr.Dataset(_conversion_factors).to_array(dim=_dim)

    return new_unit_coords, conversion_factor_coords
337
+
240
338
  def _create_indices(self):
241
339
  """Use if indices should be added to sim.indices and sim.observations"""
242
340
  pass
@@ -273,6 +371,18 @@ class GutsBase(SimulationBase):
273
371
  "is calculated without a dense time resolution, the estimates can be biased!"
274
372
  ))
275
373
 
374
+ def _determine_background_mortality_parameter(self):
375
+ if "hb" in self.config.model_parameters.all:
376
+ self.background_mortality = "hb"
377
+ elif "h_b" in self.config.model_parameters.all:
378
+ self.background_mortality = "h_b"
379
+ else:
380
+ raise GutsBaseError(
381
+ "The background mortality parameter is not defined as 'hb' or 'h_b'. " +
382
+ f"The defined parameters are {self.config.model_parameters.all}"
383
+ )
384
+
385
+
276
386
  def recompute_posterior(self):
277
387
  """This function interpolates the posterior with a given resolution
278
388
  posterior_predictions calculate proper survival predictions for the
@@ -414,6 +524,28 @@ class GutsBase(SimulationBase):
414
524
  arrays = {"survival": survival_array}
415
525
  return xr.Dataset(arrays)
416
526
 
527
+ @property
528
+ def _exposure_dimension(self):
529
+ standard_dims = [
530
+ self.config.simulation.batch_dimension,
531
+ self.config.simulation.x_dimension
532
+ ]
533
+
534
+ extra_dims = []
535
+ for k in self.config.data_structure["exposure"].dimensions:
536
+ if k not in standard_dims:
537
+ extra_dims.append(k)
538
+ else:
539
+ pass
540
+
541
+ if len(extra_dims) > 1:
542
+ raise GutsBaseError(
543
+ "Guts Base can currently only handle one exposure dimension beside" +
544
+ "the standard dimensions."
545
+ )
546
+ else:
547
+ return extra_dims[0]
548
+
417
549
 
418
550
  def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
419
551
  """This method will take an existing coordinate of a dataset that has the same
@@ -452,6 +584,7 @@ class GutsBase(SimulationBase):
452
584
  # apply automatic broadcasting to increase the size of the new dimension
453
585
  # data_var_np1_d = data_var * dummy_3d
454
586
  data_var_np1_d = data_var * dummy_categorical
587
+ data_var_np1_d.attrs = data_var.attrs
455
588
 
456
589
  # annotate coordinates of the new dimension
457
590
  data_var_np1_d = data_var_np1_d.assign_coords({
@@ -464,6 +597,40 @@ class GutsBase(SimulationBase):
464
597
 
465
598
  return obs
466
599
 
600
def map_batch_coordinates_to_extra_dim_coordinates(
    self,
    observations: xr.Dataset,
    target_dimension: str,
    coordinates: Optional[List[str]] = None
) -> xr.Dataset:
    """Re-index batch-like coordinates along ``target_dimension`` where possible.

    A coordinate is re-assigned to ``target_dimension`` when it

    - is listed in ``coordinates`` (all coordinates if ``None``),
    - is defined over the batch dimension only and is not itself a dimension,
    - has exactly as many unique values as ``target_dimension`` has entries.

    The unique values are kept in order of first occurrence. Coordinates with
    more than one dimension are skipped with a warning.

    ``observations`` is modified in place and also returned.
    """
    if coordinates is None:
        coordinates = list(observations.coords.keys())

    batch_dimension = self.config.simulation.batch_dimension

    # iterate over a snapshot, because the loop assigns into ``observations``
    for key, coord in list(observations.coords.items()):
        # skips coords, if not specified in coordinates
        if key not in coordinates:
            continue

        # only batch-like coordinates that are not dimensions themselves
        if batch_dimension not in coord.dims or key in observations.dims:
            continue

        if len(coord.dims) != 1:
            warnings.warn(
                f"Coordinate '{key}' has dimensions {coord.dims}. " +
                "Mapping coordinates with more than 1 dimension to the extra " +
                f"dimension '{target_dimension}' is not supported yet."
            )
            continue

        dim_coords = self._unique_unsorted(coord.values)
        if len(dim_coords) == len(observations[target_dimension]):
            observations[key] = (target_dimension, list(dim_coords))

    return observations
633
+
467
634
 
468
635
  def reduce_dimension_to_batch_like_coordinate(self, dimension, variables):
469
636
  """This method takes an existing dimension from a N-D array and reduces it to an
@@ -639,6 +806,60 @@ class GutsBase(SimulationBase):
639
806
  self.config.simulation.skip_data_processing = False
640
807
  super().export(directory=directory)
641
808
 
809
def copy(self):
    """Create a copy of the simulation by exporting it and loading it again.

    The simulation is exported to a temporary directory and re-created with
    ``type(self).from_directory``. Warnings raised during export and import
    are suppressed.

    The temporary directory is deliberately not used as a context manager:
    that would remove the directory and its contents prematurely, which
    interacts badly with netcdf files that are only transferred into memory
    when their contents are needed. Instead the directory handle is stored as
    ``tempdir`` on the returned copy so that it can be cleaned up once the
    copy is no longer needed.

    TODO: If this works out well integrate into pymob. I have the feeling
    there will still be some problems down the line.
    """
    tempdir = tempfile.TemporaryDirectory()
    try:
        # catch_warnings() + simplefilter also works on Python < 3.11, where
        # catch_warnings does not accept the 'action' argument
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.export(directory=tempdir.name)
            sim_copy = type(self).from_directory(tempdir.name)
    except BaseException:
        # nobody will own the directory if the copy could not be created
        tempdir.cleanup()
        raise

    # set this attribute so that it can be cleaned up after the copied sim is
    # no longer necessary.
    sim_copy.tempdir = tempdir

    return sim_copy
828
+
829
+ @staticmethod
830
+ def _condition_posterior(
831
+ posterior: xr.Dataset,
832
+ parameter: str,
833
+ value: float,
834
+ exception: Literal["raise", "warn"]="raise"
835
+ ):
836
+ """TODO: Provide this method also to SimulationBase"""
837
+ if parameter not in posterior:
838
+ keys = list(posterior.keys())
839
+ msg = (
840
+ f"{parameter=} was not found in the posterior {keys=}. " +
841
+ f"Unable to condition the posterior to {value=}. Have you "+
842
+ "requested the correct parameter for conditioning?"
843
+ )
844
+
845
+ if exception == "raise":
846
+ raise GutsBaseError(msg)
847
+ elif exception == "warn":
848
+ warnings.warn(msg)
849
+ else:
850
+ raise GutsBaseError(
851
+ "Use one of exception='raise' or exception='warn'. " +
852
+ f"Currently using {exception=}"
853
+ )
854
+
855
+ # broadcast value so that methods like drawing samples and hdi still work
856
+ broadcasted_value = np.full_like(posterior[parameter], value)
857
+
858
+ return posterior.assign({
859
+ parameter: (posterior[parameter].dims, broadcasted_value)
860
+ })
861
+
862
+
642
863
  class GutsSimulationConstantExposure(GutsBase):
643
864
  t_max = 10
644
865
  def initialize_from_script(self):
@@ -23,6 +23,25 @@ class ECxEstimator:
23
23
  This must be a pymob.SimulationBase object. If the ECxEstimator.estimate method
24
24
  is used with the modes 'draw' or 'mean'
25
25
 
26
+ effect : str
27
+ The data variable for which the effect concentration is computed. This is one
28
+ of sim.observations and sim.results
29
+
30
+ x : float
31
+ Effect level. This is the level of the effect, for which the concentration is
32
+ computed.
33
+
34
+ time : float
35
+ Time at which the ECx is computed
36
+
37
+ x_in : xr.Dataset
38
+ The model input 'x_in' for which the effect is computed.
39
+
40
+ conditions_posterior : Dict
41
+ Dictionary that overwrites values in the posterior. This is useful if for instance
42
+ background mortality should be set to a fixed value (e.g. zero). Consequently this
43
+ setting does not take effect in estimation mode 'manual' but only for mean and
44
+ draws. Defaults to an empty dict (no conditions applied).
26
45
  """
27
46
  _name = "EC"
28
47
  _parameter_msg = (
@@ -40,23 +59,30 @@ class ECxEstimator:
40
59
  x: float,
41
60
  time: float,
42
61
  x_in: xr.Dataset,
62
+ conditions_posterior: Dict[str, float] = {}
43
63
  ):
44
64
  self.sim = sim.copy()
45
65
  self.time = time
46
66
  self.x = x
47
67
  self.effect = effect
48
68
  self._mode = None
69
+ self._conditions_posterior = conditions_posterior
49
70
 
50
71
  # creates an empty observation dataset with the coordinates of the
51
72
  # original observations (especially time), except the ID, which is overwritten
52
73
  # and taken from the x_in dataset
53
74
  pseudo_obs = self.sim.observations.isel(id=[0])
54
- pseudo_obs = pseudo_obs.drop(["exposure","survival"])
75
+ pseudo_obs = pseudo_obs.drop([v for v in pseudo_obs.data_vars.keys()])
55
76
  pseudo_obs["id"] = x_in["id"]
56
77
 
57
78
  self.sim.config.data_structure.survival.observed = False
58
79
  self.sim.observations = pseudo_obs
59
80
 
81
+ # overwrite x_in to make sure that parse_input takes x_in from exposure and
82
+ # does not use the string that is tied to another data variable which was
83
+ # originally present
84
+ self.sim.config.simulation.x_in = ["exposure=exposure"]
85
+
60
86
  # ensure correct coordinate order with x_in and raise errors early
61
87
  self.sim.model_parameters["x_in"] = self.sim.parse_input("x_in", x_in)
62
88
 
@@ -69,7 +95,7 @@ class ECxEstimator:
69
95
  self.sim.coordinates["time"], np.array(time, ndmin=1)
70
96
  ]))
71
97
 
72
- self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims="time")
98
+ self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims=["time"])
73
99
  self.sim.dispatch_constructor()
74
100
 
75
101
  self.results = pd.Series({
@@ -84,6 +110,12 @@ class ECxEstimator:
84
110
  self.figure_profile_and_effect = None
85
111
  self.figure_loss_curve = None
86
112
 
113
+ self.condition_posterior_parameters(conditions=conditions_posterior)
114
+
115
+ def __del__(self):
116
+ if hasattr(self.sim, "tempdir"):
117
+ self.sim.tempdir.cleanup()
118
+
87
119
 
88
120
  def _assert_posterior(self):
89
121
  try:
@@ -94,7 +126,14 @@ class ECxEstimator:
94
126
  "('sim.inferer.idata.posterior'). " + self._parameter_msg
95
127
  )
96
128
 
97
-
129
def condition_posterior_parameters(self, conditions: Dict[str, float]):
    """Fix posterior parameters of the copied simulation to the given values.

    Every ``parameter: value`` pair in ``conditions`` is applied in turn with
    ``sim._condition_posterior`` and the result replaces
    ``sim.inferer.idata.posterior``. A parameter that is missing from the
    posterior raises (``exception="raise"``).
    """
    idata = self.sim.inferer.idata
    for name in conditions:
        idata.posterior = self.sim._condition_posterior(
            posterior=idata.posterior,
            parameter=name,
            value=conditions[name],
            exception="raise",
        )
98
137
 
99
138
  def _evaluate(self, factor, theta):
100
139
  evaluator = self.sim.dispatch(
@@ -223,23 +262,31 @@ class ECxEstimator:
223
262
  else:
224
263
  pass
225
264
 
226
- warnings.warn(
227
- "Values passed to 'parameters' don't have an effect in mode='draws'"
228
- )
265
+ if parameters is not None:
266
+ warnings.warn(
267
+ "Values passed to 'parameters' don't have an effect in mode='draws'"
268
+ )
229
269
 
230
270
  elif mode == "mean":
231
271
  self._assert_posterior()
232
272
 
233
273
  draws = 1
234
274
 
235
- warnings.warn(
236
- "Values passed to 'parameters' don't have an effect in mode='draws'"
237
- )
275
+ if parameters is not None:
276
+ warnings.warn(
277
+ "Values passed to 'parameters' don't have an effect in mode='draws'"
278
+ )
238
279
 
239
280
  elif mode == "manual":
240
281
  draws = 1
241
282
  if parameters is None:
242
283
  raise GutsBaseError(self._parameter_msg)
284
+
285
+ if self._conditions_posterior is not None:
286
+ warnings.warn(
287
+ "Conditions applied to the posterior do not take effect in mode 'manual'"
288
+ )
289
+
243
290
  else:
244
291
  raise GutsBaseError(
245
292
  f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
@@ -0,0 +1,178 @@
1
+ import os
2
+ import itertools as it
3
+ from typing import List, Dict
4
+ import numpy as np
5
+ import pandas as pd
6
+ import xarray as xr
7
+
8
+ from pymob import SimulationBase
9
+ from pymob.sim.report import Report, reporting
10
+
11
+ from guts_base.plot import plot_survival_multipanel, plot_exposure_multipanel
12
+ from guts_base.sim.ecx import ECxEstimator
13
+
14
class GutsReport(Report):
    """Report with GUTS specific sections: exposure, survival fits and LCx estimates."""

    # values passed as ``time`` to the ECxEstimator in LCx_estimates
    ecx_estimates_times: List = [1, 2, 4, 10]
    # values passed as effect level ``x`` to the ECxEstimator in LCx_estimates
    ecx_estimates_x: List = [0.1, 0.25, 0.5, 0.75, 0.9]
    # if True, the background mortality parameter of the posterior is fixed
    # to zero for the LCx estimation
    set_background_mortality_to_zero = True

    def additional_reports(self, sim: "SimulationBase"):
        """Run the reports of the parent class, then the survival fits and LCx estimates."""
        super().additional_reports(sim=sim)
        self.model_fits(sim)
        self.LCx_estimates(sim)

    @reporting
    def model_input(self, sim: SimulationBase):
        """Write the exposure section and return the path of the exposure figure."""
        self._write("### Exposure conditions")
        self._write(
            "These are the exposure conditions that were assumed for parameter inference. "
            "Double check if they are aligned with your expectations. Especially short "
            "exposure durations may not be perceivable in this view. In this case it is "
            "recommended to have a look at the exposure conditions in the numerical "
            "tables provided below."
        )

        out_mp = plot_exposure_multipanel(
            sim=sim,
            results=sim.model_parameters["x_in"],
            ncols=6,
        )

        lab = self._label.format(placeholder='exposure')
        # the backslash is escaped explicitly; '\l' is an invalid escape sequence
        self._write(f"![Exposure model fits.\\label{{{lab}}}]({os.path.basename(out_mp)})")

        return out_mp

    @reporting
    def model_fits(self, sim: SimulationBase):
        """Write the survival fit section and return the path of the survival figure."""
        self._write("### Survival model fits")

        self._write(
            "Survival observations on the unit scale with model fits. The solid line is "
            "the average of individual survival probability predictions from multiple "
            "draws from the posterior parameter distribution. In case a point estimator "
            "was used the solid line indicates the best fit. Grey uncertainty intervals "
            "indicate the uncertainty in survival probabilities. Note that the survival "
            "probabilities indicate the probability for a given individual or population "
            "to be alive when observed at time t."
        )

        out_mp = plot_survival_multipanel(
            sim=sim,
            results=sim.inferer.idata.posterior_model_fits,
            ncols=6,
        )

        lab = self._label.format(placeholder='survival_fits')
        # the backslash is escaped explicitly; '\l' is an invalid escape sequence
        self._write(f"![Surival model fits.\\label{{{lab}}}]({os.path.basename(out_mp)})")

        return out_mp

    @reporting
    def LCx_estimates(self, sim):
        """Estimate LCx values for all combinations of effect level, time and scenario.

        One ECxEstimator is created per combination of ``ecx_estimates_x``,
        ``ecx_estimates_times`` and the scenarios returned by
        ``sim.predefined_scenarios()``. The results are written as a table.
        """
        X = self.ecx_estimates_x
        T = self.ecx_estimates_times
        P = sim.predefined_scenarios()

        # without the else branch 'conditions' would be unbound (NameError)
        # whenever set_background_mortality_to_zero is switched off
        if self.set_background_mortality_to_zero:
            conditions = {sim.background_mortality: 0.0}
        else:
            conditions = {}

        estimates = pd.DataFrame(
            it.product(X, T, P.keys()),
            columns=["x", "time", "scenario"]
        )

        ecx = []

        for _, row in estimates.iterrows():
            ecx_estimator = ECxEstimator(
                sim=sim,
                effect="survival",
                x=row.x,
                time=row.time,
                x_in=P[row.scenario],
                conditions_posterior=conditions
            )

            ecx_estimator.estimate(
                mode=sim.ecx_mode,
                draws=250,
                show_plot=False
            )

            ecx.append(ecx_estimator.results.copy(deep=True))

            # remove ecx_estimator to not blow up temp files.
            # This triggers the __del__ method of ECxEstimator,
            # which cleans up a temporary directory if it was
            # created during init
            del ecx_estimator

        results = pd.DataFrame(ecx)
        estimates[results.columns] = results

        out = self._write_table(tab=estimates, label_insert="$LC_x$ estimates")

        return out
117
+
118
+
119
class ParameterConverter:
    """Convert a copy of a simulation to other exposure units and validate the result.

    Subclasses implement ``convert_parameters``. After the conversion the
    copied and the original simulation are evaluated with the default
    parameters, the posterior mean and the posterior MAP, and the variable
    ``H`` of both results has to agree within tolerance.
    """

    def __init__(
        self,
        sim: SimulationBase,
    ):
        self.sim = sim.copy()

        # rescale exposure (units and values) in the copy, then rescale the
        # parameters and rebuild the simulation input and the evaluator
        self.sim._convert_exposure_units()
        self.convert_parameters()
        self.sim.prepare_simulation_input()
        self.sim.dispatch_constructor()

        # for visual debugging, plot_exposure_and_effect(self.sim, sim, ...)
        # can be called at this point

        # each check fails with an AssertionError from numpy.testing if the
        # parameters were not rescaled consistently
        checks = (
            self.validate_parameter_conversion_default_params,
            self.validate_parameter_conversion_posterior_mean,
            self.validate_parameter_conversion_posterior_map,
        )
        for check in checks:
            check(sim_copy=self.sim, sim_orig=sim)

    def convert_parameters(self):
        """Rescale the model parameters of ``self.sim``; to be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def plot_exposure_and_effect(sim_copy, sim_orig, _id=1, data_var="survival"):
        """Plot exposure (top) and ``data_var`` (bottom) of both simulations for one ID.

        Both simulations are evaluated with the parameter values of their
        config. ``_id`` is a positional index along the ``id`` dimension.
        Returns the matplotlib figure.
        """
        from matplotlib import pyplot as plt

        fig, panels = plt.subplots(2,1)
        results_copy = sim_copy.evaluate(parameters=sim_copy.config.model_parameters.value_dict)
        results_orig = sim_orig.evaluate(parameters=sim_orig.config.model_parameters.value_dict)

        for panel, variable in zip(panels, ("exposure", data_var)):
            panel.plot(
                results_orig.time, results_orig[variable].isel(id=_id),
                color="red", label="unscaled"
            )
            panel.plot(
                results_copy.time, results_copy[variable].isel(id=_id),
                color="blue", ls="--", label="scaled"
            )

        for panel in panels:
            panel.legend()

        return fig

    @staticmethod
    def validate_parameter_conversion_default_params(sim_copy, sim_orig):
        """Compare ``H`` of both simulations evaluated with their config parameter values."""
        theta_copy = sim_copy.config.model_parameters.value_dict
        results_copy = sim_copy.evaluate(parameters=theta_copy)
        theta_orig = sim_orig.config.model_parameters.value_dict
        results_orig = sim_orig.evaluate(parameters=theta_orig)

        np.testing.assert_allclose(
            results_copy.H, results_orig.H, atol=0.001, rtol=0.001
        )

    @staticmethod
    def validate_parameter_conversion_posterior_mean(sim_copy, sim_orig):
        """Compare ``H`` of both simulations evaluated at their ``"mean"`` point estimate."""
        theta_copy = sim_copy.point_estimate("mean", to="dict")
        results_copy = sim_copy.evaluate(parameters=theta_copy)
        theta_orig = sim_orig.point_estimate("mean", to="dict")
        results_orig = sim_orig.evaluate(parameters=theta_orig)

        np.testing.assert_allclose(
            results_copy.H, results_orig.H, atol=0.001, rtol=0.001
        )

    @staticmethod
    def validate_parameter_conversion_posterior_map(sim_copy, sim_orig):
        """Compare ``H`` of both simulations evaluated at their ``"map"`` point estimate."""
        theta_copy = sim_copy.point_estimate("map", to="dict")
        results_copy = sim_copy.evaluate(parameters=theta_copy)
        theta_orig = sim_orig.point_estimate("map", to="dict")
        results_orig = sim_orig.evaluate(parameters=theta_orig)

        np.testing.assert_allclose(
            results_copy.H, results_orig.H, atol=0.001, rtol=0.001
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 1.0.0
3
+ Version: 1.0.2
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -693,6 +693,7 @@ Requires-Dist: openpyxl>=3.1.3
693
693
  Requires-Dist: Bottleneck>=1.5.0
694
694
  Requires-Dist: expydb>=0.6.0
695
695
  Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
696
+ Requires-Dist: pint
696
697
  Provides-Extra: dev
697
698
  Requires-Dist: pytest>=7.3; extra == "dev"
698
699
  Requires-Dist: bumpver; extra == "dev"
@@ -2,6 +2,7 @@ openpyxl>=3.1.3
2
2
  Bottleneck>=1.5.0
3
3
  expydb>=0.6.0
4
4
  pymob[interactive,numpyro]<0.6.0,>=0.5.10
5
+ pint
5
6
 
6
7
  [dev]
7
8
  pytest>=7.3
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "guts_base"
7
- version = "1.0.0"
7
+ version = "1.0.2"
8
8
  authors = [
9
9
  { name="Florian Schunck", email="fluncki@protonmail.com" },
10
10
  ]
@@ -15,7 +15,8 @@ dependencies=[
15
15
  "openpyxl >= 3.1.3",
16
16
  "Bottleneck >= 1.5.0",
17
17
  "expydb >= 0.6.0",
18
- "pymob[numpyro,interactive] >= 0.5.10, < 0.6.0"
18
+ "pymob[numpyro,interactive] >= 0.5.10, < 0.6.0",
19
+ "pint",
19
20
  ]
20
21
  license = {file = "LICENSE"}
21
22
  classifiers = [
@@ -48,7 +49,7 @@ import-openguts = "guts_base.data.openguts:create_database_and_import_data"
48
49
  convert-time-of-death-to-openguts = "guts_base.data.time_of_death:time_of_death_to_openguts"
49
50
 
50
51
  [tool.bumpver]
51
- current_version = "1.0.0"
52
+ current_version = "1.0.2"
52
53
  version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
53
54
  commit_message = "bump version {old_version} -> {new_version}"
54
55
  tag_message = "{new_version}"
@@ -1,72 +0,0 @@
1
- import os
2
- import itertools as it
3
- from typing import List
4
- import pandas as pd
5
-
6
- from pymob import SimulationBase
7
- from pymob.sim.report import Report, reporting
8
-
9
- from guts_base.plot import plot_survival_multipanel
10
- from guts_base.sim.ecx import ECxEstimator
11
-
12
class GutsReport(Report):
    """Report with GUTS specific sections: survival fits and LCx estimates."""

    # values passed as ``time`` to the ECxEstimator in LCx_estimates
    ecx_estimates_times: List = [1, 2, 4, 10]
    # values passed as effect level ``x`` to the ECxEstimator in LCx_estimates
    ecx_estimates_x: List = [0.1, 0.25, 0.5, 0.75, 0.9]

    def additional_reports(self, sim: "SimulationBase"):
        """Run the reports of the parent class, then the survival fits and LCx estimates."""
        super().additional_reports(sim=sim)
        self.model_fits(sim)
        self.LCx_estimates(sim)

    @reporting
    def model_fits(self, sim: SimulationBase):
        """Write the survival fit section and return the path of the survival figure."""
        self._write("### Survival model fits")

        out_mp = plot_survival_multipanel(
            sim=sim,
            results=sim.inferer.idata.posterior_model_fits,
            ncols=6,
        )

        lab = self._label.format(placeholder='survival_fits')
        # the backslash is escaped explicitly; '\l' is an invalid escape sequence
        self._write(f"![Surival model fits.\\label{{{lab}}}]({os.path.basename(out_mp)})")

        return out_mp

    @reporting
    def LCx_estimates(self, sim):
        """Estimate LCx values for all combinations of effect level, time and scenario.

        One ECxEstimator is created per combination of ``ecx_estimates_x``,
        ``ecx_estimates_times`` and the scenarios returned by
        ``sim.predefined_scenarios()``. The results are written as a table.
        """
        X = self.ecx_estimates_x
        T = self.ecx_estimates_times
        P = sim.predefined_scenarios()

        estimates = pd.DataFrame(
            it.product(X, T, P.keys()),
            columns=["x", "time", "scenario"]
        )

        ecx = []

        for _, row in estimates.iterrows():
            ecx_estimator = ECxEstimator(
                sim=sim,
                effect="survival",
                x=row.x,
                time=row.time,
                x_in=P[row.scenario],
            )

            ecx_estimator.estimate(
                mode=sim.ecx_mode,
                draws=250,
                show_plot=False
            )

            ecx.append(ecx_estimator.results)

        results = pd.DataFrame(ecx)
        estimates[results.columns] = results

        out = self._write_table(tab=estimates, label_insert="$LC_x$ estimates")

        return out
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes