guts-base 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

guts_base/__init__.py CHANGED
@@ -4,7 +4,7 @@ from . import data
4
4
  from . import prob
5
5
  from . import plot
6
6
 
7
- __version__ = "1.0.4"
7
+ __version__ = "1.0.6"
8
8
 
9
9
  from .sim import (
10
10
  GutsBase,
guts_base/sim/base.py CHANGED
@@ -101,7 +101,12 @@ class GutsBase(SimulationBase):
101
101
  self.unit_time = self.config.simulation.unit_time # type: ignore
102
102
 
103
103
  if hasattr(self.config.simulation, "skip_data_processing"):
104
- self._skip_data_processing = bool(self.config.simulation.skip_data_processing) # type: ignore
104
+ self._skip_data_processing = not (
105
+ self.config.simulation.skip_data_processing == "False" or
106
+ self.config.simulation.skip_data_processing == "false" or # type: ignore
107
+ self.config.simulation.skip_data_processing == "" or # type: ignore
108
+ self.config.simulation.skip_data_processing == 0 # type: ignore
109
+ )
105
110
 
106
111
  if hasattr(self.config.simulation, "results_interpolation"):
107
112
  results_interpolation_string = string_to_list(self.config.simulation.results_interpolation)
@@ -418,8 +423,6 @@ class GutsBase(SimulationBase):
418
423
 
419
424
  self.dispatch_constructor()
420
425
  _ = self._prob.posterior_predictions(self, self.inferer.idata) # type: ignore
421
- self.inferer.store_results(output=f"{self.output_path}/numpyro_posterior_interp.nc") # type: ignore
422
- self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
423
426
 
424
427
 
425
428
  def prior_predictive_checks(self, **plot_kwargs):
@@ -430,9 +433,12 @@ class GutsBase(SimulationBase):
430
433
  def posterior_predictive_checks(self, **plot_kwargs):
431
434
  super().posterior_predictive_checks(**plot_kwargs)
432
435
 
433
- self.recompute_posterior()
436
+ sim_copy = self.copy()
437
+ sim_copy.recompute_posterior()
434
438
  # TODO: Include posterior_predictive group once the survival predictions are correctly working
435
- self._plot.plot_posterior_predictions(self, data_vars=["survival"], groups=["posterior_model_fits"])
439
+ sim_copy._plot.plot_posterior_predictions(
440
+ sim_copy, data_vars=["survival"], groups=["posterior_model_fits"]
441
+ )
436
442
 
437
443
 
438
444
  def plot(self, results):
@@ -802,29 +808,19 @@ class GutsBase(SimulationBase):
802
808
 
803
809
  self.dispatch_constructor()
804
810
 
805
- def export(self, directory: Optional[str] = None):
806
- self.config.simulation.skip_data_processing = False
807
- super().export(directory=directory)
808
-
809
- def copy(self):
810
- """Creates a copy of a SimulationBase object by deepcopying all loose references
811
- TODO: If this works out well integrate into pymob. I have the feeling there will
812
- still be some problems down the line.
813
- """
814
- with warnings.catch_warnings(action="ignore"):
815
- # using the context manager here will not work, because it will prematurely
816
- # remove the tempdir and its contents. This interacts badly with netcdf files.
817
- # which are only transferred into memory when the contents are needed.
818
- tempdir = tempfile.TemporaryDirectory()
819
- self.export(directory=tempdir.name)
820
- print(os.listdir(tempdir.name))
821
- sim_copy = type(self).from_directory(tempdir.name)
822
-
823
- # set this attribute so that it can be cleaned up after the copied sim is
824
- # no longer necessary.
825
- sim_copy.tempdir = tempdir
826
-
827
- return sim_copy
811
+ def export(self, directory: Optional[str] = None, mode="export", skip_data_processing=True):
812
+ self.config.simulation.skip_data_processing = skip_data_processing
813
+ super().export(directory=directory, mode=mode)
814
+
815
+ def export_to_scenario(self, scenario, force=False):
816
+ """Exports a case study as a new scenario for running inference"""
817
+ self.config.case_study.scenario = scenario
818
+ self.config.case_study.data = None
819
+ self.config.case_study.output = None
820
+ self.config.case_study.scenario_path_override = None
821
+ self.config.simulation.skip_data_processing = True
822
+ self.save_observations(filename=f"observations_{scenario}.nc", force=force)
823
+ self.config.save(force=force)
828
824
 
829
825
  @staticmethod
830
826
  def _condition_posterior(
guts_base/sim/ecx.py CHANGED
@@ -112,11 +112,6 @@ class ECxEstimator:
112
112
 
113
113
  self.condition_posterior_parameters(conditions=conditions_posterior)
114
114
 
115
- def __del__(self):
116
- if hasattr(self.sim, "tempdir"):
117
- self.sim.tempdir.cleanup()
118
-
119
-
120
115
  def _assert_posterior(self):
121
116
  try:
122
117
  p = self.sim.inferer.idata.posterior
guts_base/sim/report.py CHANGED
@@ -1,12 +1,15 @@
1
+ from functools import partial
1
2
  import os
2
3
  import itertools as it
3
- from typing import List, Dict
4
+ from typing import List, Dict, Literal, Optional, Union
4
5
  import numpy as np
5
6
  import pandas as pd
6
7
  import xarray as xr
8
+ import arviz as az
7
9
 
8
10
  from pymob import SimulationBase
9
11
  from pymob.sim.report import Report, reporting
12
+ from pymob.inference.analysis import round_to_sigfig, format_parameter
10
13
 
11
14
  from guts_base.plot import plot_survival_multipanel, plot_exposure_multipanel
12
15
  from guts_base.sim.ecx import ECxEstimator
@@ -17,6 +20,12 @@ class GutsReport(Report):
17
20
  ecx_draws: int = 250
18
21
  ecx_force_draws: bool = False
19
22
  set_background_mortality_to_zero = True
23
+ table_parameter_stat_focus = "mean"
24
+ units = xr.Dataset({
25
+ "metric": ["unit"],
26
+ "k_d": ("metric", ["1/t"])
27
+ })
28
+
20
29
 
21
30
  def additional_reports(self, sim: "SimulationBase"):
22
31
  super().additional_reports(sim=sim)
@@ -105,12 +114,6 @@ class GutsReport(Report):
105
114
 
106
115
  ecx.append(ecx_estimator.results.copy(deep=True))
107
116
 
108
- # remove ecx_estimator to not blow up temp files.
109
- # This triggers the __del__ method of ECxEstimator,
110
- # which cleans up a temporary directory if it was
111
- # created during init
112
- del ecx_estimator
113
-
114
117
  results = pd.DataFrame(ecx)
115
118
  estimates[results.columns] = results
116
119
 
@@ -119,6 +122,164 @@ class GutsReport(Report):
119
122
  return out
120
123
 
121
124
 
125
+ @reporting
126
+ def table_parameter_estimates(self, posterior, indices):
127
+
128
+ if self.rc.table_parameter_estimates_with_batch_dim_vars:
129
+ var_names = {
130
+ k: k for k, v in self.config.model_parameters.free.items()
131
+ }
132
+ else:
133
+ var_names = {
134
+ k: k for k, v in self.config.model_parameters.free.items()
135
+ if self.config.simulation.batch_dimension not in v.dims
136
+ }
137
+
138
+ var_names.update(self.rc.table_parameter_estimates_override_names)
139
+
140
+ if len(self.rc.table_parameter_estimates_exclude_vars) > 0:
141
+ self._write(f"Excluding parameters: {self.rc.table_parameter_estimates_exclude_vars} for meaningful visualization")
142
+
143
+ var_names = {
144
+ k: k for k, v in var_names.items()
145
+ if k not in self.rc.table_parameter_estimates_exclude_vars
146
+ }
147
+
148
+ tab_report = create_table(
149
+ posterior=posterior,
150
+ vars=var_names,
151
+ error_metric=self.rc.table_parameter_estimates_error_metric,
152
+ units=self.units,
153
+ significant_figures=self.rc.table_parameter_estimates_significant_figures,
154
+ nesting_dimension=indices.keys(),
155
+ parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
156
+ )
157
+
158
+ # rewrite table in the desired output format
159
+ tab = create_table(
160
+ posterior=posterior,
161
+ vars=var_names,
162
+ error_metric=self.rc.table_parameter_estimates_error_metric,
163
+ units=self.units,
164
+ significant_figures=self.rc.table_parameter_estimates_significant_figures,
165
+ fmt=self.rc.table_parameter_estimates_format,
166
+ nesting_dimension=indices.keys(),
167
+ parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
168
+ )
169
+
170
+ self._write_table(tab=tab, tab_report=tab_report, label_insert="Parameter estimates")
171
+
172
+
173
+ def create_table(
174
+ posterior,
175
+ error_metric: Literal["hdi","sd"] = "hdi",
176
+ vars: Dict = {},
177
+ nesting_dimension: Optional[Union[List,str]] = None,
178
+ units: xr.Dataset = xr.Dataset(),
179
+ fmt: Literal["csv", "tsv", "latex"] = "csv",
180
+ significant_figures: int = 3,
181
+ parameters_as_rows: bool = True,
182
+ ) -> pd.DataFrame:
183
+ """The function is not ready to deal with any nesting dimensionality
184
+ and currently expects the 2-D case
185
+ """
186
+ tab = az.summary(
187
+ posterior, var_names=list(vars.keys()),
188
+ fmt="xarray", kind="stats", stat_focus="mean",
189
+ hdi_prob=0.94
190
+ )
191
+
192
+ tab = tab.rename(vars)
193
+
194
+ _units = flatten_coords(
195
+ dataset=create_units(dataset=tab, defined_units=units),
196
+ keep_dims=["metric"]
197
+ )
198
+ tab = flatten_coords(dataset=tab, keep_dims=["metric"])
199
+
200
+ tab = tab.apply(np.vectorize(
201
+ partial(round_to_sigfig, sig_fig=significant_figures)
202
+ ))
203
+
204
+
205
+ if error_metric == "sd":
206
+ arrays = []
207
+ for _, data_var in tab.data_vars.items():
208
+ par_formatted = data_var.sel(metric=["mean", "sd"])\
209
+ .astype(str).str\
210
+ .join("metric", sep=" ± ")
211
+ arrays.append(par_formatted)
212
+
213
+
214
+ table = xr.combine_by_coords(arrays)
215
+ table = table.assign_coords(metric="mean ± std").expand_dims("metric")
216
+ table = table.to_dataframe().T
217
+
218
+ elif error_metric == "hdi":
219
+ stacked_tab = tab.sel(metric=["mean", "hdi_3%", "hdi_97%"])\
220
+ .assign_coords(metric=["mean", "hdi 3%", "hdi 97%"])
221
+ table = stacked_tab.to_dataframe().T
222
+
223
+ else:
224
+ raise NotImplementedError("Must use one of 'sd' or 'hdi'")
225
+
226
+
227
+ if fmt == "latex":
228
+ table.columns.names = [c.replace('_',' ') for c in table.columns.names]
229
+ table.index = [format_parameter(i) for i in list(table.index)]
230
+ table = table.rename(
231
+ columns={"hdi 3%": "hdi 3\\%", "hdi 97%": "hdi 97\\%"}
232
+ )
233
+ else:
234
+ pass
235
+
236
+ table["unit"] = _units.to_pandas().T
237
+
238
+
239
+ if parameters_as_rows:
240
+ return table
241
+ else:
242
+ return table.T
243
+
244
+ def flatten_coords(dataset: xr.Dataset, keep_dims):
245
+ """flattens extra coordinates beside the keep_dim dimension for all data variables
246
+ producing a array with harmonized dimensions
247
+ """
248
+ ds = dataset.copy()
249
+ for var_name, data_var in ds.data_vars.items():
250
+ extra_coords = [k for k in list(data_var.coords.keys()) if k not in keep_dims]
251
+ if len(extra_coords) == 0:
252
+ continue
253
+
254
+ data_var_ = data_var.stack(index=extra_coords)
255
+
256
+ # otherwise
257
+ for idx in data_var_["index"].values:
258
+ new_var_name = f"{var_name}[{','.join([str(e) for e in idx])}]"
259
+ # reset coordinates to move non-dim index coords from coordinates to the
260
+ # data variables and then select only the var_name from the data vars
261
+ new_data_var = data_var_.sel({"index": idx}).reset_coords()[var_name]
262
+ ds[new_var_name] = new_data_var
263
+
264
+ ds = ds.drop(var_name)
265
+
266
+ # drop any coordinates that should not be in the dataset at this stage
267
+ extra_coords = [k for k in list(ds.coords.keys()) if k not in keep_dims]
268
+ ds = ds.drop(extra_coords)
269
+
270
+ return ds
271
+
272
+ def create_units(dataset: xr.Dataset, defined_units: xr.Dataset):
273
+ units = dataset.sel(metric=["mean"]).astype(str)
274
+ units = units.assign_coords({"metric": ("metric", ["unit"])})
275
+ for k, u in units.data_vars.items():
276
+ if k in defined_units:
277
+ units = units.assign({k: defined_units[k].astype(units[k].dtype)})
278
+ else:
279
+ units[k].values = np.full_like(u.values, "")
280
+
281
+ return units
282
+
122
283
  class ParameterConverter:
123
284
  def __init__(
124
285
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 1.0.4
3
+ Version: 1.0.6
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,4 +1,4 @@
1
- guts_base/__init__.py,sha256=xW56Yo4_P0LRfBpKaMEOwnTnZsmTajA4breTU1mVad0,227
1
+ guts_base/__init__.py,sha256=mqsDPsRl9TvFmIRICGE6wDPG-d1TEAzO0Sj1Rfurl5A,227
2
2
  guts_base/mod.py,sha256=AzOCg1A8FP5EtVfp-66HT7G7h_wnHkenieaxTc9qCyk,5796
3
3
  guts_base/plot.py,sha256=Sr_d0sXHNajPLPelcGl72yCOEEqB7NGNWhViKYAiTng,6692
4
4
  guts_base/prob.py,sha256=ITwo5dAGMHr5xTldilHMbKU6AFsWo4_ZwbfaXh97Gew,5443
@@ -11,15 +11,15 @@ guts_base/data/survival.py,sha256=U-Ehloo8vnD81VeIglXLEUHX9lt7SjtEs2YEB0D9FHE,50
11
11
  guts_base/data/time_of_death.py,sha256=hwngUwfRP3u8WmD3dHyXrphuu5d8ZJTKyBovGRwAHNQ,21014
12
12
  guts_base/data/utils.py,sha256=u3gGDJK15MfRUP4iIxsS-I1oqxD2qH_ugsT7o_Eac18,236
13
13
  guts_base/sim/__init__.py,sha256=sbHmT1p2saN0MJ-iYnCDOHjkHtTcKgm7X-dHX5o0tYA,435
14
- guts_base/sim/base.py,sha256=XKW2_bFxb4oUdyb2Y8dfFLQ_mLmJhaHmqfhHyp97240,37788
14
+ guts_base/sim/base.py,sha256=xd4VroOS7KM8Ap7hYGaD85hB3n_8IT8GuEvecYu3TnE,37549
15
15
  guts_base/sim/constructors.py,sha256=Kz9FHIH3EHsSIKd9sQgHa3eveniFifFlk1Hf-QR69Pg,875
16
- guts_base/sim/ecx.py,sha256=QONOWuv2jvjq1UAwiojuvj_jJyxfz7haewASjj4oCbs,20913
16
+ guts_base/sim/ecx.py,sha256=PeX8UVF1HMMNHaIU-jL6dml4JGezhWwiGSqPFJbOFM4,20808
17
17
  guts_base/sim/mempy.py,sha256=IHd87UrmdXpC7y7q0IjYQJH075frjbp2a-dMVBeqZ0U,10164
18
- guts_base/sim/report.py,sha256=__6BGnG6c3DJXNUO39R0WkMuf8A1PWAeTePThz7ydKo,7040
18
+ guts_base/sim/report.py,sha256=o19MBhKcwty2auPjYWoz4QY91jjJFkA80UTzUuZo1oE,12720
19
19
  guts_base/sim/utils.py,sha256=Qj_FPH6kywVxOwgCerS7w5CyuYR9HKmvBWFpmxwDFgk,256
20
- guts_base-1.0.4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
21
- guts_base-1.0.4.dist-info/METADATA,sha256=fbcVvL249G1AjdaBQg8_ASjLhyKsnYqU1zv7rA20E4c,45426
22
- guts_base-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- guts_base-1.0.4.dist-info/entry_points.txt,sha256=icsHzG2jQ90ZS7XvLsI5Qj0-qGuPv2T0RBVN5daGCPU,183
24
- guts_base-1.0.4.dist-info/top_level.txt,sha256=PxhBgUd4r39W_VI4FyJjARwKbV5_glgCVnd6v_zAGdE,10
25
- guts_base-1.0.4.dist-info/RECORD,,
20
+ guts_base-1.0.6.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
21
+ guts_base-1.0.6.dist-info/METADATA,sha256=scr05XWOqcMi7XsHVrlyflI-4WuDq1B-w6OFz9mu6Ho,45426
22
+ guts_base-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ guts_base-1.0.6.dist-info/entry_points.txt,sha256=icsHzG2jQ90ZS7XvLsI5Qj0-qGuPv2T0RBVN5daGCPU,183
24
+ guts_base-1.0.6.dist-info/top_level.txt,sha256=PxhBgUd4r39W_VI4FyJjARwKbV5_glgCVnd6v_zAGdE,10
25
+ guts_base-1.0.6.dist-info/RECORD,,