guts-base 1.0.5__tar.gz → 1.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base has been flagged as potentially problematic; see the package's page on the source registry for more details.

Files changed (36):
  1. {guts_base-1.0.5 → guts_base-1.0.6}/PKG-INFO +1 -1
  2. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/__init__.py +1 -1
  3. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/base.py +24 -26
  4. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/report.py +168 -1
  5. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/PKG-INFO +1 -1
  6. {guts_base-1.0.5 → guts_base-1.0.6}/pyproject.toml +2 -2
  7. {guts_base-1.0.5 → guts_base-1.0.6}/LICENSE +0 -0
  8. {guts_base-1.0.5 → guts_base-1.0.6}/README.md +0 -0
  9. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/__init__.py +0 -0
  10. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/expydb.py +0 -0
  11. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/generator.py +0 -0
  12. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/openguts.py +0 -0
  13. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/preprocessing.py +0 -0
  14. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/survival.py +0 -0
  15. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/time_of_death.py +0 -0
  16. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/utils.py +0 -0
  17. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/mod.py +0 -0
  18. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/plot.py +0 -0
  19. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/prob.py +0 -0
  20. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/__init__.py +0 -0
  21. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/constructors.py +0 -0
  22. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/ecx.py +0 -0
  23. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/mempy.py +0 -0
  24. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/utils.py +0 -0
  25. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/SOURCES.txt +0 -0
  26. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/dependency_links.txt +0 -0
  27. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/entry_points.txt +0 -0
  28. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/requires.txt +0 -0
  29. {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/top_level.txt +0 -0
  30. {guts_base-1.0.5 → guts_base-1.0.6}/setup.cfg +0 -0
  31. {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_data_import.py +0 -0
  32. {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_ecx.py +0 -0
  33. {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_from_pymob.py +0 -0
  34. {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_scripted_simulations.py +0 -0
  35. {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_simulations.py +0 -0
  36. {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_symbolic_solve.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 1.0.5
3
+ Version: 1.0.6
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ from . import data
4
4
  from . import prob
5
5
  from . import plot
6
6
 
7
- __version__ = "1.0.5"
7
+ __version__ = "1.0.6"
8
8
 
9
9
  from .sim import (
10
10
  GutsBase,
@@ -101,7 +101,12 @@ class GutsBase(SimulationBase):
101
101
  self.unit_time = self.config.simulation.unit_time # type: ignore
102
102
 
103
103
  if hasattr(self.config.simulation, "skip_data_processing"):
104
- self._skip_data_processing = bool(self.config.simulation.skip_data_processing) # type: ignore
104
+ self._skip_data_processing = not (
105
+ self.config.simulation.skip_data_processing == "False" or
106
+ self.config.simulation.skip_data_processing == "false" or # type: ignore
107
+ self.config.simulation.skip_data_processing == "" or # type: ignore
108
+ self.config.simulation.skip_data_processing == 0 # type: ignore
109
+ )
105
110
 
106
111
  if hasattr(self.config.simulation, "results_interpolation"):
107
112
  results_interpolation_string = string_to_list(self.config.simulation.results_interpolation)
@@ -418,8 +423,6 @@ class GutsBase(SimulationBase):
418
423
 
419
424
  self.dispatch_constructor()
420
425
  _ = self._prob.posterior_predictions(self, self.inferer.idata) # type: ignore
421
- self.inferer.store_results(output=f"{self.output_path}/numpyro_posterior_interp.nc") # type: ignore
422
- self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
423
426
 
424
427
 
425
428
  def prior_predictive_checks(self, **plot_kwargs):
@@ -430,9 +433,12 @@ class GutsBase(SimulationBase):
430
433
  def posterior_predictive_checks(self, **plot_kwargs):
431
434
  super().posterior_predictive_checks(**plot_kwargs)
432
435
 
433
- self.recompute_posterior()
436
+ sim_copy = self.copy()
437
+ sim_copy.recompute_posterior()
434
438
  # TODO: Include posterior_predictive group once the survival predictions are correctly working
435
- self._plot.plot_posterior_predictions(self, data_vars=["survival"], groups=["posterior_model_fits"])
439
+ sim_copy._plot.plot_posterior_predictions(
440
+ sim_copy, data_vars=["survival"], groups=["posterior_model_fits"]
441
+ )
436
442
 
437
443
 
438
444
  def plot(self, results):
@@ -802,27 +808,19 @@ class GutsBase(SimulationBase):
802
808
 
803
809
  self.dispatch_constructor()
804
810
 
805
- def export(self, directory: Optional[str] = None):
806
- self.config.simulation.skip_data_processing = False
807
- super().export(directory=directory)
808
-
809
- def copy(self):
810
- """Creates a copy of a SimulationBase object by deepcopying all loose references
811
- TODO: If this works out well integrate into pymob. I have the feeling there will
812
- still be some problems down the line.
813
- """
814
- with warnings.catch_warnings(action="ignore"):
815
- # create the tempdir in the output path, because on the remote cluster
816
- # the default temporary directory may not have enough space. Using the output
817
- # path here resolves any path issues.
818
- tmp_basedir = os.path.join(self.output_path, "tmp")
819
- os.makedirs(tmp_basedir, exist_ok=True)
820
- with tempfile.TemporaryDirectory(dir=tmp_basedir) as name:
821
- self.export(directory=name)
822
- print(f"Exported files ({name}):", os.listdir(name))
823
- sim_copy = type(self).from_directory(name)
824
-
825
- return sim_copy
811
+ def export(self, directory: Optional[str] = None, mode="export", skip_data_processing=True):
812
+ self.config.simulation.skip_data_processing = skip_data_processing
813
+ super().export(directory=directory, mode=mode)
814
+
815
+ def export_to_scenario(self, scenario, force=False):
816
+ """Exports a case study as a new scenario for running inference"""
817
+ self.config.case_study.scenario = scenario
818
+ self.config.case_study.data = None
819
+ self.config.case_study.output = None
820
+ self.config.case_study.scenario_path_override = None
821
+ self.config.simulation.skip_data_processing = True
822
+ self.save_observations(filename=f"observations_{scenario}.nc", force=force)
823
+ self.config.save(force=force)
826
824
 
827
825
  @staticmethod
828
826
  def _condition_posterior(
@@ -1,12 +1,15 @@
1
+ from functools import partial
1
2
  import os
2
3
  import itertools as it
3
- from typing import List, Dict
4
+ from typing import List, Dict, Literal, Optional, Union
4
5
  import numpy as np
5
6
  import pandas as pd
6
7
  import xarray as xr
8
+ import arviz as az
7
9
 
8
10
  from pymob import SimulationBase
9
11
  from pymob.sim.report import Report, reporting
12
+ from pymob.inference.analysis import round_to_sigfig, format_parameter
10
13
 
11
14
  from guts_base.plot import plot_survival_multipanel, plot_exposure_multipanel
12
15
  from guts_base.sim.ecx import ECxEstimator
@@ -17,6 +20,12 @@ class GutsReport(Report):
17
20
  ecx_draws: int = 250
18
21
  ecx_force_draws: bool = False
19
22
  set_background_mortality_to_zero = True
23
+ table_parameter_stat_focus = "mean"
24
+ units = xr.Dataset({
25
+ "metric": ["unit"],
26
+ "k_d": ("metric", ["1/t"])
27
+ })
28
+
20
29
 
21
30
  def additional_reports(self, sim: "SimulationBase"):
22
31
  super().additional_reports(sim=sim)
@@ -113,6 +122,164 @@ class GutsReport(Report):
113
122
  return out
114
123
 
115
124
 
125
+ @reporting
126
+ def table_parameter_estimates(self, posterior, indices):
127
+
128
+ if self.rc.table_parameter_estimates_with_batch_dim_vars:
129
+ var_names = {
130
+ k: k for k, v in self.config.model_parameters.free.items()
131
+ }
132
+ else:
133
+ var_names = {
134
+ k: k for k, v in self.config.model_parameters.free.items()
135
+ if self.config.simulation.batch_dimension not in v.dims
136
+ }
137
+
138
+ var_names.update(self.rc.table_parameter_estimates_override_names)
139
+
140
+ if len(self.rc.table_parameter_estimates_exclude_vars) > 0:
141
+ self._write(f"Excluding parameters: {self.rc.table_parameter_estimates_exclude_vars} for meaningful visualization")
142
+
143
+ var_names = {
144
+ k: k for k, v in var_names.items()
145
+ if k not in self.rc.table_parameter_estimates_exclude_vars
146
+ }
147
+
148
+ tab_report = create_table(
149
+ posterior=posterior,
150
+ vars=var_names,
151
+ error_metric=self.rc.table_parameter_estimates_error_metric,
152
+ units=self.units,
153
+ significant_figures=self.rc.table_parameter_estimates_significant_figures,
154
+ nesting_dimension=indices.keys(),
155
+ parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
156
+ )
157
+
158
+ # rewrite table in the desired output format
159
+ tab = create_table(
160
+ posterior=posterior,
161
+ vars=var_names,
162
+ error_metric=self.rc.table_parameter_estimates_error_metric,
163
+ units=self.units,
164
+ significant_figures=self.rc.table_parameter_estimates_significant_figures,
165
+ fmt=self.rc.table_parameter_estimates_format,
166
+ nesting_dimension=indices.keys(),
167
+ parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
168
+ )
169
+
170
+ self._write_table(tab=tab, tab_report=tab_report, label_insert="Parameter estimates")
171
+
172
+
173
+ def create_table(
174
+ posterior,
175
+ error_metric: Literal["hdi","sd"] = "hdi",
176
+ vars: Dict = {},
177
+ nesting_dimension: Optional[Union[List,str]] = None,
178
+ units: xr.Dataset = xr.Dataset(),
179
+ fmt: Literal["csv", "tsv", "latex"] = "csv",
180
+ significant_figures: int = 3,
181
+ parameters_as_rows: bool = True,
182
+ ) -> pd.DataFrame:
183
+ """The function is not ready to deal with any nesting dimensionality
184
+ and currently expects the 2-D case
185
+ """
186
+ tab = az.summary(
187
+ posterior, var_names=list(vars.keys()),
188
+ fmt="xarray", kind="stats", stat_focus="mean",
189
+ hdi_prob=0.94
190
+ )
191
+
192
+ tab = tab.rename(vars)
193
+
194
+ _units = flatten_coords(
195
+ dataset=create_units(dataset=tab, defined_units=units),
196
+ keep_dims=["metric"]
197
+ )
198
+ tab = flatten_coords(dataset=tab, keep_dims=["metric"])
199
+
200
+ tab = tab.apply(np.vectorize(
201
+ partial(round_to_sigfig, sig_fig=significant_figures)
202
+ ))
203
+
204
+
205
+ if error_metric == "sd":
206
+ arrays = []
207
+ for _, data_var in tab.data_vars.items():
208
+ par_formatted = data_var.sel(metric=["mean", "sd"])\
209
+ .astype(str).str\
210
+ .join("metric", sep=" ± ")
211
+ arrays.append(par_formatted)
212
+
213
+
214
+ table = xr.combine_by_coords(arrays)
215
+ table = table.assign_coords(metric="mean ± std").expand_dims("metric")
216
+ table = table.to_dataframe().T
217
+
218
+ elif error_metric == "hdi":
219
+ stacked_tab = tab.sel(metric=["mean", "hdi_3%", "hdi_97%"])\
220
+ .assign_coords(metric=["mean", "hdi 3%", "hdi 97%"])
221
+ table = stacked_tab.to_dataframe().T
222
+
223
+ else:
224
+ raise NotImplementedError("Must use one of 'sd' or 'hdi'")
225
+
226
+
227
+ if fmt == "latex":
228
+ table.columns.names = [c.replace('_',' ') for c in table.columns.names]
229
+ table.index = [format_parameter(i) for i in list(table.index)]
230
+ table = table.rename(
231
+ columns={"hdi 3%": "hdi 3\\%", "hdi 97%": "hdi 97\\%"}
232
+ )
233
+ else:
234
+ pass
235
+
236
+ table["unit"] = _units.to_pandas().T
237
+
238
+
239
+ if parameters_as_rows:
240
+ return table
241
+ else:
242
+ return table.T
243
+
244
+ def flatten_coords(dataset: xr.Dataset, keep_dims):
245
+ """flattens extra coordinates beside the keep_dim dimension for all data variables
246
+ producing a array with harmonized dimensions
247
+ """
248
+ ds = dataset.copy()
249
+ for var_name, data_var in ds.data_vars.items():
250
+ extra_coords = [k for k in list(data_var.coords.keys()) if k not in keep_dims]
251
+ if len(extra_coords) == 0:
252
+ continue
253
+
254
+ data_var_ = data_var.stack(index=extra_coords)
255
+
256
+ # otherwise
257
+ for idx in data_var_["index"].values:
258
+ new_var_name = f"{var_name}[{','.join([str(e) for e in idx])}]"
259
+ # reset coordinates to move non-dim index coords from coordinates to the
260
+ # data variables and then select only the var_name from the data vars
261
+ new_data_var = data_var_.sel({"index": idx}).reset_coords()[var_name]
262
+ ds[new_var_name] = new_data_var
263
+
264
+ ds = ds.drop(var_name)
265
+
266
+ # drop any coordinates that should not be in the dataset at this stage
267
+ extra_coords = [k for k in list(ds.coords.keys()) if k not in keep_dims]
268
+ ds = ds.drop(extra_coords)
269
+
270
+ return ds
271
+
272
+ def create_units(dataset: xr.Dataset, defined_units: xr.Dataset):
273
+ units = dataset.sel(metric=["mean"]).astype(str)
274
+ units = units.assign_coords({"metric": ("metric", ["unit"])})
275
+ for k, u in units.data_vars.items():
276
+ if k in defined_units:
277
+ units = units.assign({k: defined_units[k].astype(units[k].dtype)})
278
+ else:
279
+ units[k].values = np.full_like(u.values, "")
280
+
281
+ return units
282
+
116
283
  class ParameterConverter:
117
284
  def __init__(
118
285
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 1.0.5
3
+ Version: 1.0.6
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "guts_base"
7
- version = "1.0.5"
7
+ version = "1.0.6"
8
8
  authors = [
9
9
  { name="Florian Schunck", email="fluncki@protonmail.com" },
10
10
  ]
@@ -49,7 +49,7 @@ import-openguts = "guts_base.data.openguts:create_database_and_import_data"
49
49
  convert-time-of-death-to-openguts = "guts_base.data.time_of_death:time_of_death_to_openguts"
50
50
 
51
51
  [tool.bumpver]
52
- current_version = "1.0.5"
52
+ current_version = "1.0.6"
53
53
  version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
54
54
  commit_message = "bump version {old_version} -> {new_version}"
55
55
  tag_message = "{new_version}"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes