guts-base 1.0.5.tar.gz → 1.0.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of guts-base might be problematic.
- {guts_base-1.0.5 → guts_base-1.0.6}/PKG-INFO +1 -1
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/__init__.py +1 -1
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/base.py +24 -26
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/report.py +168 -1
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/PKG-INFO +1 -1
- {guts_base-1.0.5 → guts_base-1.0.6}/pyproject.toml +2 -2
- {guts_base-1.0.5 → guts_base-1.0.6}/LICENSE +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/README.md +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/__init__.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/expydb.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/generator.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/openguts.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/preprocessing.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/survival.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/time_of_death.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/data/utils.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/mod.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/plot.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/prob.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/__init__.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/constructors.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/ecx.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/mempy.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base/sim/utils.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/SOURCES.txt +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/dependency_links.txt +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/entry_points.txt +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/requires.txt +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/guts_base.egg-info/top_level.txt +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/setup.cfg +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_data_import.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_ecx.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_from_pymob.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_scripted_simulations.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_simulations.py +0 -0
- {guts_base-1.0.5 → guts_base-1.0.6}/tests/test_symbolic_solve.py +0 -0
guts_base/sim/base.py (+24 -26):

```diff
@@ -101,7 +101,12 @@ class GutsBase(SimulationBase):
         self.unit_time = self.config.simulation.unit_time  # type: ignore
 
         if hasattr(self.config.simulation, "skip_data_processing"):
-            self._skip_data_processing =
+            self._skip_data_processing = not (
+                self.config.simulation.skip_data_processing == "False" or
+                self.config.simulation.skip_data_processing == "false" or  # type: ignore
+                self.config.simulation.skip_data_processing == "" or  # type: ignore
+                self.config.simulation.skip_data_processing == 0  # type: ignore
+            )
 
         if hasattr(self.config.simulation, "results_interpolation"):
             results_interpolation_string = string_to_list(self.config.simulation.results_interpolation)
```
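The new flag handling in `GutsBase.__init__` treats the config values `"False"`, `"false"`, `""` and `0` as disabling the skip, and everything else as enabling it. A minimal standalone sketch of that coercion (illustrative only; the function name is made up and not part of the package):

```python
def coerce_skip_flag(value) -> bool:
    """Return True unless the config value spells out a falsy flag."""
    # Mirrors the logic added to GutsBase.__init__: string spellings of False,
    # the empty string, and 0 disable skipping; anything else enables it.
    return not (value == "False" or value == "false" or value == "" or value == 0)

assert coerce_skip_flag("True") is True
assert coerce_skip_flag("false") is False
assert coerce_skip_flag("") is False
assert coerce_skip_flag(0) is False
assert coerce_skip_flag(True) is True
```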
```diff
@@ -418,8 +423,6 @@ class GutsBase(SimulationBase):
 
         self.dispatch_constructor()
         _ = self._prob.posterior_predictions(self, self.inferer.idata)  # type: ignore
-        self.inferer.store_results(output=f"{self.output_path}/numpyro_posterior_interp.nc")  # type: ignore
-        self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
 
 
     def prior_predictive_checks(self, **plot_kwargs):
@@ -430,9 +433,12 @@ class GutsBase(SimulationBase):
     def posterior_predictive_checks(self, **plot_kwargs):
         super().posterior_predictive_checks(**plot_kwargs)
 
-        self.
+        sim_copy = self.copy()
+        sim_copy.recompute_posterior()
         # TODO: Include posterior_predictive group once the survival predictions are correctly working
-
+        sim_copy._plot.plot_posterior_predictions(
+            sim_copy, data_vars=["survival"], groups=["posterior_model_fits"]
+        )
 
 
     def plot(self, results):
```
```diff
@@ -802,27 +808,19 @@ class GutsBase(SimulationBase):
 
         self.dispatch_constructor()
 
-    def export(self, directory: Optional[str] = None):
-        self.config.simulation.skip_data_processing =
-        super().export(directory=directory)
-
-    def
-        """
-
-
-
-
-
-
-
-        tmp_basedir = os.path.join(self.output_path, "tmp")
-        os.makedirs(tmp_basedir, exist_ok=True)
-        with tempfile.TemporaryDirectory(dir=tmp_basedir) as name:
-            self.export(directory=name)
-            print(f"Exported files ({name}):", os.listdir(name))
-            sim_copy = type(self).from_directory(name)
-
-        return sim_copy
+    def export(self, directory: Optional[str] = None, mode="export", skip_data_processing=True):
+        self.config.simulation.skip_data_processing = skip_data_processing
+        super().export(directory=directory, mode=mode)
+
+    def export_to_scenario(self, scenario, force=False):
+        """Exports a case study as a new scenario for running inference"""
+        self.config.case_study.scenario = scenario
+        self.config.case_study.data = None
+        self.config.case_study.output = None
+        self.config.case_study.scenario_path_override = None
+        self.config.simulation.skip_data_processing = True
+        self.save_observations(filename=f"observations_{scenario}.nc", force=force)
+        self.config.save(force=force)
 
     @staticmethod
     def _condition_posterior(
```
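A hedged usage sketch of the reworked export API above; `sim` stands for an already-configured `GutsBase` simulation, and the wrapper function name, scenario name, and flag values are made up for illustration:

```python
# Not from the package docs: how the new export methods might be combined.
def publish_as_scenario(sim, scenario: str = "scenario_b") -> None:
    # export() now also controls the skip_data_processing flag written to the config
    sim.export(directory=None, mode="export", skip_data_processing=True)
    # export_to_scenario() freezes the current case study under a new scenario name,
    # writing observations_<scenario>.nc and saving the adjusted config
    sim.export_to_scenario(scenario, force=False)
```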
guts_base/sim/report.py (+168 -1):

```diff
@@ -1,12 +1,15 @@
+from functools import partial
 import os
 import itertools as it
-from typing import List, Dict
+from typing import List, Dict, Literal, Optional, Union
 import numpy as np
 import pandas as pd
 import xarray as xr
+import arviz as az
 
 from pymob import SimulationBase
 from pymob.sim.report import Report, reporting
+from pymob.inference.analysis import round_to_sigfig, format_parameter
 
 from guts_base.plot import plot_survival_multipanel, plot_exposure_multipanel
 from guts_base.sim.ecx import ECxEstimator
@@ -17,6 +20,12 @@ class GutsReport(Report):
     ecx_draws: int = 250
     ecx_force_draws: bool = False
     set_background_mortality_to_zero = True
+    table_parameter_stat_focus = "mean"
+    units = xr.Dataset({
+        "metric": ["unit"],
+        "k_d": ("metric", ["1/t"])
+    })
+
 
     def additional_reports(self, sim: "SimulationBase"):
         super().additional_reports(sim=sim)
```
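The new `units` class attribute is an `xr.Dataset` keyed by parameter name with a single `metric` dimension whose only label is `"unit"`. A small standalone sketch of how such a units table can be built (parameter names other than `k_d` and the unit strings are illustrative):

```python
import xarray as xr

# Illustrative units table in the shape the report expects: one "metric"
# dimension with the single label "unit", one entry per model parameter.
units = xr.Dataset(
    data_vars={
        "k_d": ("metric", ["1/t"]),
        "h_b": ("metric", ["1/t"]),
    },
    coords={"metric": ["unit"]},
)

print(units["k_d"].sel(metric="unit").item())  # -> "1/t"
```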
```diff
@@ -113,6 +122,164 @@
         return out
 
 
+    @reporting
+    def table_parameter_estimates(self, posterior, indices):
+
+        if self.rc.table_parameter_estimates_with_batch_dim_vars:
+            var_names = {
+                k: k for k, v in self.config.model_parameters.free.items()
+            }
+        else:
+            var_names = {
+                k: k for k, v in self.config.model_parameters.free.items()
+                if self.config.simulation.batch_dimension not in v.dims
+            }
+
+        var_names.update(self.rc.table_parameter_estimates_override_names)
+
+        if len(self.rc.table_parameter_estimates_exclude_vars) > 0:
+            self._write(f"Excluding parameters: {self.rc.table_parameter_estimates_exclude_vars} for meaningful visualization")
+
+            var_names = {
+                k: k for k, v in var_names.items()
+                if k not in self.rc.table_parameter_estimates_exclude_vars
+            }
+
+        tab_report = create_table(
+            posterior=posterior,
+            vars=var_names,
+            error_metric=self.rc.table_parameter_estimates_error_metric,
+            units=self.units,
+            significant_figures=self.rc.table_parameter_estimates_significant_figures,
+            nesting_dimension=indices.keys(),
+            parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
+        )
+
+        # rewrite table in the desired output format
+        tab = create_table(
+            posterior=posterior,
+            vars=var_names,
+            error_metric=self.rc.table_parameter_estimates_error_metric,
+            units=self.units,
+            significant_figures=self.rc.table_parameter_estimates_significant_figures,
+            fmt=self.rc.table_parameter_estimates_format,
+            nesting_dimension=indices.keys(),
+            parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
+        )
+
+        self._write_table(tab=tab, tab_report=tab_report, label_insert="Parameter estimates")
+
+
+def create_table(
+    posterior,
+    error_metric: Literal["hdi","sd"] = "hdi",
+    vars: Dict = {},
+    nesting_dimension: Optional[Union[List,str]] = None,
+    units: xr.Dataset = xr.Dataset(),
+    fmt: Literal["csv", "tsv", "latex"] = "csv",
+    significant_figures: int = 3,
+    parameters_as_rows: bool = True,
+) -> pd.DataFrame:
+    """The function is not ready to deal with any nesting dimensionality
+    and currently expects the 2-D case
+    """
+    tab = az.summary(
+        posterior, var_names=list(vars.keys()),
+        fmt="xarray", kind="stats", stat_focus="mean",
+        hdi_prob=0.94
+    )
+
+    tab = tab.rename(vars)
+
+    _units = flatten_coords(
+        dataset=create_units(dataset=tab, defined_units=units),
+        keep_dims=["metric"]
+    )
+    tab = flatten_coords(dataset=tab, keep_dims=["metric"])
+
+    tab = tab.apply(np.vectorize(
+        partial(round_to_sigfig, sig_fig=significant_figures)
+    ))
+
+
+    if error_metric == "sd":
+        arrays = []
+        for _, data_var in tab.data_vars.items():
+            par_formatted = data_var.sel(metric=["mean", "sd"])\
+                .astype(str).str\
+                .join("metric", sep=" ± ")
+            arrays.append(par_formatted)
+
+
+        table = xr.combine_by_coords(arrays)
+        table = table.assign_coords(metric="mean ± std").expand_dims("metric")
+        table = table.to_dataframe().T
+
+    elif error_metric == "hdi":
+        stacked_tab = tab.sel(metric=["mean", "hdi_3%", "hdi_97%"])\
+            .assign_coords(metric=["mean", "hdi 3%", "hdi 97%"])
+        table = stacked_tab.to_dataframe().T
+
+    else:
+        raise NotImplementedError("Must use one of 'sd' or 'hdi'")
+
+
+    if fmt == "latex":
+        table.columns.names = [c.replace('_',' ') for c in table.columns.names]
+        table.index = [format_parameter(i) for i in list(table.index)]
+        table = table.rename(
+            columns={"hdi 3%": "hdi 3\\%", "hdi 97%": "hdi 97\\%"}
+        )
+    else:
+        pass
+
+    table["unit"] = _units.to_pandas().T
+
+
+    if parameters_as_rows:
+        return table
+    else:
+        return table.T
+
+def flatten_coords(dataset: xr.Dataset, keep_dims):
+    """flattens extra coordinates beside the keep_dim dimension for all data variables
+    producing a array with harmonized dimensions
+    """
+    ds = dataset.copy()
+    for var_name, data_var in ds.data_vars.items():
+        extra_coords = [k for k in list(data_var.coords.keys()) if k not in keep_dims]
+        if len(extra_coords) == 0:
+            continue
+
+        data_var_ = data_var.stack(index=extra_coords)
+
+        # otherwise
+        for idx in data_var_["index"].values:
+            new_var_name = f"{var_name}[{','.join([str(e) for e in idx])}]"
+            # reset coordinates to move non-dim index coords from coordinates to the
+            # data variables and then select only the var_name from the data vars
+            new_data_var = data_var_.sel({"index": idx}).reset_coords()[var_name]
+            ds[new_var_name] = new_data_var
+
+        ds = ds.drop(var_name)
+
+    # drop any coordinates that should not be in the dataset at this stage
+    extra_coords = [k for k in list(ds.coords.keys()) if k not in keep_dims]
+    ds = ds.drop(extra_coords)
+
+    return ds
+
+def create_units(dataset: xr.Dataset, defined_units: xr.Dataset):
+    units = dataset.sel(metric=["mean"]).astype(str)
+    units = units.assign_coords({"metric": ("metric", ["unit"])})
+    for k, u in units.data_vars.items():
+        if k in defined_units:
+            units = units.assign({k: defined_units[k].astype(units[k].dtype)})
+        else:
+            units[k].values = np.full_like(u.values, "")
+
+    return units
+
 class ParameterConverter:
     def __init__(
         self,
```
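The new `create_table` helper builds on `arviz.summary` with `fmt="xarray"`, which returns a Dataset indexed by a `metric` dimension that the helper then reshapes and annotates with units. A minimal, self-contained sketch of that building block (the toy posterior, variable names, and values are illustrative, not package data):

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
# Toy posterior with shape (chain, draw); parameter names are made up
idata = az.from_dict(posterior={
    "k_d": rng.lognormal(mean=0.0, sigma=0.2, size=(2, 500)),
    "h_b": rng.lognormal(mean=-3.0, sigma=0.3, size=(2, 500)),
})

# fmt="xarray" yields an xr.Dataset with a "metric" dimension holding
# "mean", "sd", "hdi_3%" and "hdi_97%"; create_table() reshapes this further.
summary = az.summary(
    idata, var_names=["k_d", "h_b"],
    fmt="xarray", kind="stats", stat_focus="mean", hdi_prob=0.94,
)
print(summary.sel(metric=["mean", "hdi_3%", "hdi_97%"]).to_dataframe().T)
```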
pyproject.toml (+2 -2):

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "guts_base"
-version = "1.0.5"
+version = "1.0.6"
 authors = [
   { name="Florian Schunck", email="fluncki@protonmail.com" },
 ]
@@ -49,7 +49,7 @@ import-openguts = "guts_base.data.openguts:create_database_and_import_data"
 convert-time-of-death-to-openguts = "guts_base.data.time_of_death:time_of_death_to_openguts"
 
 [tool.bumpver]
-current_version = "1.0.5"
+current_version = "1.0.6"
 version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
 commit_message = "bump version {old_version} -> {new_version}"
 tag_message = "{new_version}"
```