guts-base 2.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
guts_base/sim/mempy.py ADDED
@@ -0,0 +1,290 @@
1
+ import pathlib
2
+ from typing import Dict, Optional, Literal, Protocol, TypedDict, List, Tuple
3
+ import re
4
+ import os
5
+ import numpy as np
6
+ from numpy.ma.core import MaskedConstant
7
+ import pandas as pd
8
+ import xarray as xr
9
+ from pymob.sim.config import Config, DataVariable, Datastructure
10
+ from pymob.sim.config import Param
11
+ from guts_base.sim import GutsBase
12
+
13
# Public API of this module: only the simulator class is exported.
__all__ = ["PymobSimulator"]
16
+
17
+ class ParamsInfoDict(TypedDict):
18
+ name: str
19
+ min: float
20
+ max: float
21
+ initial: float
22
+ vary: bool
23
+ prior: str
24
+ dims: Tuple[str]
25
+ unit: str|List[str]
26
+ module: str
27
+
28
class StateVariablesDict(TypedDict):
    """Description of one state variable listed in ``Model.state_variables``.

    Only the keys that are also fields of pymob's ``DataVariable`` end up in
    the data structure; ``y0`` (if present) is used as the starting value.
    """

    # names of the dimensions of the variable
    dimensions: List[str]
    # presumably whether the variable is observed in the data -- confirm
    observed: bool
    # starting values of the variable
    y0: List[float]
32
+
33
class Model(Protocol):
    """Structural interface of a model object that ``PymobSimulator`` can
    convert into a pymob simulation.

    The attributes are read by ``PymobSimulator.from_model_and_dataset`` and
    its helpers; the static methods are stubs that concrete models replace.
    """

    # name of the exposure dimension of the observations
    extra_dim: str
    # parameter metadata, keyed by parameter name
    params_info: Dict[str, ParamsInfoDict]
    # state variable descriptions, keyed by variable name
    state_variables: Dict[str, StateVariablesDict]
    # fallback values for keys that are missing in a params_info entry
    _params_info_defaults: Dict[str, ParamsInfoDict]
    # when true, a fixed "eps" parameter is added during conversion
    _it_model: bool

    @staticmethod
    def _rhs_jax():
        # model equations (presumably the ODE right-hand side); to be
        # provided by the concrete model
        raise NotImplementedError

    @staticmethod
    def _solver_post_processing():
        # post-processing of solver output; to be provided by the model
        raise NotImplementedError

    @staticmethod
    def _likelihood_func_jax():
        # likelihood function, referenced by its qualified name in the config
        raise NotImplementedError
51
+
52
+
53
class PymobSimulator(GutsBase):
    """GUTS simulator assembled from a model object and data frames.

    Unlike a regular pymob simulation, no configuration file is needed. The
    configuration is derived from the model instance (data structure,
    parameters, likelihood) and from the defaults set in ``_configure``.
    """

    @classmethod
    def from_model_and_dataset(
        cls,
        model: Model,
        exposure_data: Dict[str, pd.DataFrame],
        survival_data: pd.DataFrame,
        info_dict: Optional[Dict] = None,
        pymob_config: Optional[Config] = None,
        output_directory: str | pathlib.Path = pathlib.Path("output/pymob"),
        default_prior: Literal["uniform", "lognorm"] = "lognorm",
    ) -> "PymobSimulator":
        """Construct a PymobSimulator from a model instance and observed data.

        Parameters
        ----------
        model : Model
            An *instance* of a model class (not the class itself). Its
            ``extra_dim``, ``state_variables`` and ``params_info`` define the
            data structure and the model parameters.
        exposure_data : Dict[str, pd.DataFrame]
            Exposure data frames, converted to observations together with
            ``survival_data``.
        survival_data : pd.DataFrame
            Survival observations.
        info_dict : Optional[Dict]
            Currently unused.
        pymob_config : Optional[Config]
            Configuration to start from. If omitted, a new ``Config`` is
            created and filled by ``_configure``.
        output_directory : str | pathlib.Path
            Output directory. The observations are written there as
            ``observations.nc``, and the directory is also used as the
            scenario directory.
        default_prior : {"uniform", "lognorm"}
            Prior family for parameters that do not specify a prior.

        Returns
        -------
        PymobSimulator
            A simulator on which ``setup()`` has already been called.

        Raises
        ------
        TypeError
            If ``model`` is a class instead of an instance.
        """
        # Validate the model first, so that a wrong argument fails before any
        # directory is created or any observation file is written.
        if isinstance(model, type):
            raise TypeError(
                f"model '{model.__name__}' must be initialized. Initialize with "
                f"{model.__name__}(). Don't forget to specify the number of exposures "
                "with e.g. RED_SD_DA(num_expos=2) if your model has two exposures."
            )

        if pymob_config is None:
            cfg = Config()
            # _configure can be overridden in a subclass to change the defaults
            cls._configure(config=cfg)
        else:
            cfg = pymob_config

        output_directory = pathlib.Path(output_directory)

        cfg.case_study.output = str(output_directory)
        # Override the scenario path: the scenario is expected in the same
        # folder as the output.
        cfg.case_study.scenario_path_override = str(output_directory)
        cfg.case_study.scenario = output_directory.stem
        cfg.case_study.data = cfg.case_study.output_path
        cfg.case_study.observations = "observations.nc"
        cfg.create_directory(directory="results", force=True)

        obs = cls._observations_from_dataframes(
            exposure_data=exposure_data,
            survival_data=survival_data,
            exposure_dim=model.extra_dim,
            unit_input=cfg.guts_base.unit_input,
            unit_time=cfg.guts_base.unit_time,
        )
        obs.to_netcdf(os.path.join(cfg.case_study.output_path, cfg.case_study.observations))

        # Reference the model class and its likelihood function by their
        # fully qualified names (module.name).
        cfg.simulation.model_class = f"{model.__module__}.{type(model).__name__}"
        likelihood = model._likelihood_func_jax
        cfg.inference_numpyro.user_defined_error_model = (
            f"{likelihood.__module__}.{likelihood.__name__}"
        )

        # derive data structure and parameters from the model instance
        cls._set_data_structure(config=cfg, model=model)
        cls._set_params(config=cfg, model=model, default_prior=default_prior)

        # starting values and model input
        cfg.simulation.x_in = ["exposure=exposure"]
        cfg.simulation.y0 = [
            f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v
        ]

        cfg.guts_base.background_mortality_parameters = (
            cls._get_background_mortality_params(model)
        )

        # All post-processing must happen in setup(), so that the simulation
        # can also be run from automated tools like pymob-infer.
        sim = cls(config=cfg)
        sim.setup()
        return sim

    def setup(self, **evaluator_kwargs):
        """Run the base setup and keep a deep copy of the observations."""
        super().setup(**evaluator_kwargs)
        # backup of the observations as they are directly after setup
        self._obs_backup = self.observations.copy(deep=True)

    @classmethod
    def _configure(cls, config: Config):
        """Set the configuration defaults for a simulation built from a model.

        These settings are normally provided by the configuration file passed
        to a SimulationBase class. The mempy-to-pymob converter builds the
        simulation without a config file, so they are specified here.
        """
        config.case_study.output = "results"
        config.case_study.simulation = "PymobSimulator"

        # This must be "guts_base" (the name of the pip package); it controls
        # which packages are loaded.
        config.case_study.name = "guts_base"

        config.simulation.x_dimension = "time"
        config.simulation.batch_dimension = "id"
        config.simulation.solver_post_processing = None

        # unit_time, n_reindexed_x and forward_interpolate_exposure_data of
        # the registered guts-base section are left at their saved defaults.

        config.inference.extra_vars = ["eps", "survivors_before_t", "survivors_at_start"]
        config.inference.n_predictions = 100

        config.jaxsolver.diffrax_solver = "Tsit5"
        config.jaxsolver.rtol = 1e-10
        config.jaxsolver.atol = 1e-12
        config.jaxsolver.throw_exception = True
        config.jaxsolver.pcoeff = 0.3
        config.jaxsolver.icoeff = 0.3
        config.jaxsolver.dcoeff = 0.0
        config.jaxsolver.max_steps = 1000000

        config.inference_numpyro.gaussian_base_distribution = True
        config.inference_numpyro.kernel = "svi"
        config.inference_numpyro.init_strategy = "init_to_median"
        config.inference_numpyro.svi_iterations = 10_000
        config.inference_numpyro.svi_learning_rate = 0.001

    @classmethod
    def _set_data_structure(cls, config: Config, model: Model):
        """Build the data structure from ``model.state_variables``.

        Each state variable description is reduced to the keys that are
        fields of pymob's ``DataVariable`` model before validation.
        """
        state_dict = model.state_variables

        config.data_structure = Datastructure.model_validate({
            key: DataVariable.model_validate({
                k: v for k, v in state_info.items()
                if k in DataVariable.model_fields
            })
            for key, state_info in state_dict.items()
        })

    @staticmethod
    def _get_background_mortality_params(model: Model):
        """Return the names of parameters whose module is "background-mortality".

        Entries without a "module" key are not background-mortality
        parameters.
        """
        return [
            k for k, v in model.params_info.items()
            if v.get("module") == "background-mortality"
        ]

    @staticmethod
    def _format_prior_value(value) -> str:
        """Format a scalar or 1-D array as a prior argument without spaces.

        ``str(np.ndarray)`` pads elements with spaces for alignment
        (e.g. ``[ 0.5 10. ]``). Replacing spaces by commas afterwards would
        produce empty list items, so arrays are joined explicitly.
        Single-element arrays are formatted as their only element.
        """
        values = np.asarray(value)
        if values.ndim == 0:
            return f"{value}"
        if len(values) == 1:
            return f"{values[0]}"
        return "[" + ",".join(f"{v}" for v in values.tolist()) + "]"

    @classmethod
    def _set_params(cls, config: Config, model: Model, default_prior: str):
        """Convert ``model.params_info`` into pymob model parameters.

        Entries whose names differ only by a number (e.g. ``k_d_1`` and
        ``k_d_2``) are grouped into one parameter (``k_d``). Keys missing
        from an entry are filled from ``model._params_info_defaults``; this
        modifies the dicts in ``model.params_info`` in place. For models with
        ``_it_model`` set, a fixed ``eps`` parameter of ``10 * atol`` is
        added. Parameters without a prior but with min and max bounds get a
        default prior of the family ``default_prior``.

        Raises
        ------
        ValueError
            If a default prior is needed and ``default_prior`` is neither
            "uniform" nor "lognorm".
        NotImplementedError
            If the entries of one grouped parameter specify different priors.
        """
        params_info = model.params_info
        atol_pad = config.jaxsolver.atol * 10

        if model._it_model:
            params_info["eps"] = {"name": "eps", "initial": atol_pad, "vary": False}

        for param_dict in params_info.values():
            for key, default in model._params_info_defaults.items():
                param_dict.setdefault(key, default)

        param_df = pd.DataFrame(params_info).T
        # Split names like "k_d_2" into the base name "k_d" and the 0-based
        # index 1. Names with no digits, or with more than one digit group,
        # get no index.
        param_df["param_index"] = param_df.name.apply(lambda x: re.findall(r"\d+", x))
        param_df["param_index"] = param_df.param_index.apply(
            lambda x: int(x[0]) - 1 if len(x) == 1 else None
        )
        param_df["name"] = param_df.name.apply(lambda x: re.sub(r"\d+", "", x).strip("_"))

        for (param_name,), group in param_df.groupby(["name"]):
            dims = list(dict.fromkeys(group["dims"]))
            dims = tuple() if dims == [None] else tuple(dims)

            priors = list(dict.fromkeys(group["prior"]))
            if len(priors) > 1:
                # TODO: allow for parsing one N-D prior from multiple priors
                raise NotImplementedError(
                    f"Parameter '{param_name}' has different priors for its "
                    f"elements: {priors}. Specify one prior for all elements."
                )
            prior = priors[0]

            # all-NaN bounds reduce to the masked constant and count as "no bound"
            _min = np.min(np.ma.masked_invalid(group["min"].values.astype(float)))
            _max = np.max(np.ma.masked_invalid(group["max"].values.astype(float)))
            _init = np.array(group["initial"].values.astype(float))
            _free = np.array(group["vary"].values)

            units = list(dict.fromkeys(group["unit"]))
            unit = units[0] if len(units) == 1 else units

            if isinstance(_min, MaskedConstant):
                _min = None
            if isinstance(_max, MaskedConstant):
                _max = None

            # TODO: vary=False priors could be parsed as deterministic and
            # combined with a free prior into a composite model input.

            if prior is None:
                fixed = np.logical_not(_free)
                if _min is None or _max is None:
                    prior = None
                elif default_prior == "uniform":
                    # free elements use (min, max); fixed elements use
                    # (initial - 10*atol, initial + 10*atol)
                    _loc = _init * fixed + _min * _free - atol_pad * fixed
                    _scale = _init * fixed + _max * _free + atol_pad * fixed
                    prior = (
                        f"uniform(loc={cls._format_prior_value(_loc)},"
                        f"scale={cls._format_prior_value(_scale)})"
                    )
                elif default_prior == "lognorm":
                    # s=3 for free elements, 10*atol for fixed elements
                    _s = 3 * _free + atol_pad * fixed
                    # NOTE(review): this also collapses a single-element _init
                    # to a scalar for the Param value below, while the other
                    # branches keep a 1-element array -- confirm intent.
                    _init = _init[0] if len(_init) == 1 else _init
                    prior = (
                        f"lognorm(scale={cls._format_prior_value(_init)},"
                        f"s={cls._format_prior_value(_s)})"
                    )
                else:
                    raise ValueError(
                        f"Default prior: '{default_prior}' is not implemented. "
                        "Use one of 'uniform', 'lognorm' or specify priors for each "
                        "parameter directly with: "
                        f"`model.params_info['{param_name}']['prior'] = '<distribution>(...)'`"
                    )

            if prior is not None:
                # user-given priors may separate arguments with spaces
                prior = prior.replace(" ", ",")

            param = Param.model_validate(dict(
                value=_init,
                free=np.max(_free),
                min=_min,
                max=_max,
                prior=prior,
                dims=dims,
                unit=unit,
            ))

            setattr(config.model_parameters, param_name, param)
290
+
@@ -0,0 +1,405 @@
1
+ from functools import partial
2
+ import os
3
+ import itertools as it
4
+ from typing import List, Dict, Literal, Optional, Union, TYPE_CHECKING
5
+ import numpy as np
6
+ import pandas as pd
7
+ import xarray as xr
8
+ import arviz as az
9
+ from matplotlib import pyplot as plt
10
+ import pint
11
+
12
+ from pymob import SimulationBase
13
+ from pymob.sim.report import Report, reporting
14
+ from pymob.sim.config import Config, string_to_dict
15
+ from pymob.inference.analysis import round_to_sigfig, format_parameter
16
+
17
+ from guts_base.plot import plot_survival_multipanel, plot_exposure_multipanel
18
+ from guts_base.sim.ecx import ECxEstimator
19
+ from guts_base.sim.config import GutsBaseConfig
20
+ from guts_base.sim import units
21
+
22
+ if TYPE_CHECKING:
23
+ from guts_base.sim import GutsBase
24
+
25
+
26
class GutsReport(Report):
    """Report with GUTS-specific sections: exposure conditions, survival model
    fits, LCx estimates and a unit-aware parameter estimate table.

    The ECx settings (times, effect levels, draws, background-mortality
    handling) and the table statistic focus are read from
    ``config.guts_base``.
    """

    def additional_reports(self, sim: "GutsBase"):
        """Append the GUTS-specific sections after the base report sections."""
        super().additional_reports(sim=sim)
        self.model_input(sim)
        self.model_fits(sim)
        self.LCx_estimates(sim)

    @property
    def units(self):
        """Explicit units derived from the config and the observations' "unit" variable."""
        return units.derive_explicit_units(
            config=self.config,
            unit=self.observations["unit"]
        )

    @reporting
    def model_input(self, sim: "GutsBase"):
        """Write the exposure-conditions section with a multipanel exposure plot."""
        self._write("### Exposure conditions")
        self._write(
            "These are the exposure conditions that were assumed for parameter inference. "
            "Double check if they are aligned with your expectations. Especially short "
            "exposure durations may not be perceivable in this view. In this case it is "
            "recommended to have a look at the exposure conditions in the numerical "
            "tables provided below."
        )

        out_mp = plot_exposure_multipanel(
            sim=sim,
            results=sim.model_parameters["x_in"],
            ncols=6,
        )

        lab = self._label.format(placeholder='exposure')
        # "\\label" emits a literal "\label" for LaTeX without an invalid escape
        self._write(f"![Exposure model fits.\\label{{{lab}}}]({os.path.basename(out_mp)})")

        return out_mp

    @reporting
    def model_fits(self, sim: "GutsBase"):
        """Write the survival-fit section with a multipanel survival plot."""
        self._write("### Survival model fits")

        self._write(
            "Survival observations on the unit scale with model fits. The solid line is "
            "the average of individual survival probability predictions from multiple "
            "draws from the posterior parameter distribution. In case a point estimator "
            "was used the solid line indicates the best fit. Grey uncertainty intervals "
            "indicate the uncertainty in survival probabilities. Note that the survival "
            "probabilities indicate the probability for a given individual or population "
            "to be alive when observed at time t."
        )

        if sim._exclude_controls_after_fixing_background_mortality:  # type: ignore
            lab = self._label.format(placeholder='survival_fits_controls')

            # reuse the plot generated in GutsBase.estimate_background_mortality
            self._write(
                f"![Survival model fits (control treatments).\\label{{{lab}}}]"
                "(survival_multipanel_control_treatments.png)"
            )

        out_mp = plot_survival_multipanel(
            sim=sim,
            results=sim.inferer.idata.posterior_model_fits,
            ncols=6,
        )

        lab = self._label.format(placeholder='survival_fits')
        self._write(f"![Survival model fits.\\label{{{lab}}}]({os.path.basename(out_mp)})")

        return out_mp

    @reporting
    def LCx_estimates(self, sim: "GutsBase"):
        """Estimate LCx for every combination of effect level, time and
        predefined exposure scenario, and write the results as a table.
        """
        guts_cfg = self.config.guts_base
        X = guts_cfg.ecx_estimates_x
        T = guts_cfg.ecx_estimates_times
        P = sim.predefined_scenarios()

        # Posterior conditions for the estimator. Without this default the
        # name was unbound when background mortality is not set to zero.
        conditions = {}
        if guts_cfg.ecx_set_background_mortality_to_zero:
            conditions = {p: 0.0 for p in guts_cfg.background_mortality_parameters}

        estimates = pd.DataFrame(
            it.product(X, T, P.keys()),
            columns=["x", "time", "scenario"]
        )

        ecx = []

        for _, row in estimates.iterrows():
            msg = "Estimating EC_{x}(t={t} {tu}) for scenario '{s}'".format(
                x=row.x * 100, t=row.time, s=row.scenario, tu=guts_cfg.unit_time
            )
            print(msg)
            print("=" * len(msg))

            ecx_estimator = ECxEstimator(
                sim=sim,
                effect="survival",
                x=row.x,
                time=row.time,
                x_in=P[row.scenario],
                conditions_posterior=conditions
            )

            # scan a very wide log interval to find a good starting value
            ecx_estimator.plot_loss_curve(
                mode="mean",
                log_x0=0,
                log_interval_radius=20,
                log_interval_num=100
            )

            plt.close(ecx_estimator.figure_loss_curve)

            ecx_estimator.estimate(
                mode=sim.config.guts_base.ecx_mode,
                log_x0=float(np.mean(ecx_estimator.ecx_candidates)),
                draws=guts_cfg.ecx_draws,
                force_draws=guts_cfg.ecx_force_draws,
                show_plot=False
            )

            ecx.append(ecx_estimator.results.copy(deep=True))

        results = pd.DataFrame(ecx)
        estimates[results.columns] = results

        out = self._write_table(tab=estimates, label_insert="$LC_x$ estimates")

        return out

    @reporting
    def table_parameter_estimates(self, posterior, indices):
        """Write the table of free-parameter estimates.

        Parameters with the batch dimension are included only if configured.
        Override names from the report config replace the displayed names, and
        excluded variables are dropped. The table is created once for the
        report and once in the configured output format.
        """
        free_params = self.config.model_parameters.free
        if self.rc.table_parameter_estimates_with_batch_dim_vars:
            var_names = {k: k for k, v in free_params.items()}
        else:
            batch_dim = self.config.simulation.batch_dimension
            var_names = {
                k: k for k, v in free_params.items()
                if batch_dim not in v.dims
            }

        var_names.update(self.rc.table_parameter_estimates_override_names)  # type: ignore

        exclude = self.rc.table_parameter_estimates_exclude_vars
        if len(exclude) > 0:
            self._write(f"Excluding parameters: {exclude} for meaningful visualization")

        # keep the (possibly overridden) display name as the value
        var_names = {k: v for k, v in var_names.items() if k not in exclude}

        table_kwargs = dict(
            posterior=posterior,
            vars=var_names,
            error_metric=self.rc.table_parameter_estimates_error_metric,
            units=self.units,
            stat_focus=self.config.guts_base.table_parameter_stat_focus,
            significant_figures=self.rc.table_parameter_estimates_significant_figures,
            nesting_dimension=indices.keys(),
            parameters_as_rows=self.rc.table_parameter_estimates_parameters_as_rows,
        )

        tab_report = create_table(**table_kwargs)
        # the same table in the desired output format
        tab = create_table(**table_kwargs, fmt=self.rc.table_parameter_estimates_format)

        self._write_table(tab=tab, tab_report=tab_report, label_insert="Parameter estimates")
218
+
219
+
220
def create_table(
    posterior,
    error_metric: Literal["hdi", "sd"] = "hdi",
    vars: Optional[Dict] = None,
    nesting_dimension: Optional[Union[List, str]] = None,
    stat_focus: Literal["mean", "median"] = "mean",
    units: Optional[xr.Dataset] = None,
    fmt: Literal["csv", "tsv", "latex"] = "csv",
    significant_figures: int = 3,
    parameters_as_rows: bool = True,
) -> pd.DataFrame:
    """Summarize posterior parameter estimates as a formatted table.

    The function is not ready to deal with any nesting dimensionality and
    currently expects the 2-D case.

    Parameters
    ----------
    posterior
        Posterior samples passed to ``arviz.summary``.
    error_metric : {"hdi", "sd"}
        "hdi": central statistic plus the 3% / 97% bounds (HDI for
        ``stat_focus="mean"``, ETI for "median"). "sd": a single
        "<stat> ± <spread>" column (sd for mean, mad for median).
    vars : dict, optional
        Maps posterior variable names to the names shown in the table.
    nesting_dimension : optional
        Currently unused.
    stat_focus : {"mean", "median"}
        Central statistic reported by ``arviz.summary``.
    units : xr.Dataset, optional
        Units per variable; variables without a unit get an empty string.
    fmt : {"csv", "tsv", "latex"}
        "latex" escapes "%" and formats the parameter names.
    significant_figures : int
        Number of significant figures of the rounded statistics.
    parameters_as_rows : bool
        If False, the table is transposed.

    Raises
    ------
    NotImplementedError
        If ``error_metric`` is neither "sd" nor "hdi".
    """
    # fail before the (potentially expensive) posterior summary is computed
    if error_metric not in ("sd", "hdi"):
        raise NotImplementedError("Must use one of 'sd' or 'hdi'")

    var_map = {} if vars is None else vars
    defined_units = xr.Dataset() if units is None else units

    tab = az.summary(
        posterior, var_names=list(var_map.keys()),
        fmt="xarray", kind="stats", stat_focus=stat_focus,
        hdi_prob=0.94
    )

    if TYPE_CHECKING:
        # narrows the type for static checkers only
        assert isinstance(tab, xr.Dataset)

    tab = tab.rename(var_map)

    _units = flatten_coords(
        dataset=create_units(dataset=tab, defined_units=defined_units, stat_focus=stat_focus),
        keep_dims=["metric"]
    )
    tab = flatten_coords(dataset=tab, keep_dims=["metric"])

    tab = tab.map(np.vectorize(
        partial(round_to_sigfig, sig_fig=significant_figures)
    ))

    if stat_focus == "mean" and error_metric == "sd":
        metrics = ["mean", "sd"]
    elif stat_focus == "mean" and error_metric == "hdi":
        metrics = ["mean", "hdi_3%", "hdi_97%"]
    elif stat_focus == "median" and error_metric == "sd":
        metrics = ["median", "mad"]
    else:
        metrics = ["median", "eti_3%", "eti_97%"]

    if error_metric == "sd":
        # join "<stat> ± <spread>" into a single string per parameter
        arrays = [
            data_var.sel(metric=metrics).astype(str).str.join("metric", sep=" ± ")
            for data_var in tab.data_vars.values()
        ]
        table = xr.combine_by_coords(arrays)
        table = table.assign_coords(metric=" ± ".join(metrics)).expand_dims("metric")
        table = table.to_dataframe().T
    else:
        stacked_tab = tab.sel(metric=metrics)\
            .assign_coords(metric=[m.replace("_", " ") for m in metrics])
        table = stacked_tab.to_dataframe().T

    if fmt == "latex":
        table.columns.names = [str(c).replace('_', ' ') for c in table.columns.names]
        table.index = [format_parameter(i) for i in list(table.index)]
        table = table.rename(
            columns={c: c.replace("%", "\\%") for c in table.columns}
        )

    table["unit"] = _units.to_pandas().T

    return table if parameters_as_rows else table.T
305
+
306
def flatten_coords(dataset: xr.Dataset, keep_dims):
    """Flatten all coordinates except ``keep_dims`` into the variable names.

    Each data variable with coordinates besides ``keep_dims`` is split into
    one variable per coordinate combination, named ``"<var>[<c1>,<c2>,...]"``.
    All resulting variables then share only the ``keep_dims`` dimensions.
    The input dataset is not modified.
    """
    ds = dataset.copy().reset_coords(drop=True)

    # Iterate over a snapshot: ds gains and loses variables inside the loop.
    for var_name, data_var in list(ds.data_vars.items()):
        extra_coords = [k for k in data_var.coords if k not in keep_dims]
        if not extra_coords:
            continue

        stacked = data_var.stack(index=extra_coords)
        for idx in stacked["index"].values:
            new_var_name = f"{var_name}[{','.join(str(e) for e in idx)}]"
            # reset_coords moves the leftover non-dim coordinates of the
            # selection into data variables; keep only the variable itself
            ds[new_var_name] = stacked.sel({"index": idx}).reset_coords()[var_name]

        ds = ds.drop_vars(var_name)

    # drop any coordinates that should not be in the dataset at this stage
    leftover_coords = [k for k in ds.coords if k not in keep_dims]
    return ds.drop_vars(leftover_coords)
334
+
335
def create_units(dataset: xr.Dataset, defined_units: xr.Dataset, stat_focus):
    """Build a one-row dataset of unit strings that mirrors ``dataset``.

    The ``stat_focus`` row of ``dataset`` (converted to strings) serves as a
    template, and its metric label is renamed to "unit". Variables present in
    ``defined_units`` take their unit from there, cast to the template's
    string dtype. All other variables get an empty string.
    """
    result = (
        dataset.sel(metric=[stat_focus])
        .astype(str)
        .assign_coords({"metric": ("metric", ["unit"])})
    )

    for name in list(result.data_vars):
        if name not in defined_units:
            # unknown unit: blank string of the same shape and dtype
            result[name].values = np.full_like(result[name].values, "")
            continue
        result = result.assign({name: defined_units[name].astype(result[name].dtype)})

    return result
345
+
346
class ParameterConverter:
    """Rescale model parameters after converting the exposure units.

    The simulation is copied. In the copy, the exposure units are converted,
    the parameters are rescaled by ``convert_parameters`` (to be implemented
    by subclasses) and the simulation is re-initialized. The copy is then
    validated against the original: the ``H`` results must agree (within
    tolerance) for the default parameters, the posterior mean and the
    posterior MAP.
    """

    def __init__(
        self,
        sim: "GutsBase",
    ):
        self.sim = sim.copy()

        # converts the exposure units of the copy and scales its exposure data
        self.sim._convert_exposure_units()
        self.convert_parameters()
        self.sim.prepare_simulation_input()
        self.sim.dispatch_constructor()

        # raise an AssertionError if the parameters were not rescaled correctly
        self.validate_parameter_conversion_default_params(sim_copy=self.sim, sim_orig=sim)
        self.validate_parameter_conversion_posterior_mean(sim_copy=self.sim, sim_orig=sim)
        self.validate_parameter_conversion_posterior_map(sim_copy=self.sim, sim_orig=sim)

    def convert_parameters(self):
        """Rescale the parameters of ``self.sim``; implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def plot_exposure_and_effect(sim_copy, sim_orig, _id=1, data_var="survival"):
        """Debug plot comparing exposure (top) and ``data_var`` (bottom) of the
        original (red, solid) and converted (blue, dashed) simulation for one id.
        """
        fig, (ax_exposure, ax_effect) = plt.subplots(2, 1)
        results_copy = sim_copy.evaluate(parameters=sim_copy.config.model_parameters.value_dict)
        results_orig = sim_orig.evaluate(parameters=sim_orig.config.model_parameters.value_dict)

        for ax, var in ((ax_exposure, "exposure"), (ax_effect, data_var)):
            ax.plot(results_orig.time, results_orig[var].isel(id=_id), color="red", label="unscaled")
            ax.plot(results_copy.time, results_copy[var].isel(id=_id), color="blue", ls="--", label="scaled")
            ax.legend()
        return fig

    @staticmethod
    def _assert_equal_results(sim_copy, sim_orig, get_parameters):
        """Evaluate both simulations with ``get_parameters(sim)`` and assert
        that their ``H`` results agree (atol=0.1, rtol=0.05).
        """
        results_copy = sim_copy.evaluate(parameters=get_parameters(sim_copy))
        results_orig = sim_orig.evaluate(parameters=get_parameters(sim_orig))

        np.testing.assert_allclose(results_copy.H, results_orig.H, atol=0.1, rtol=0.05)

    @staticmethod
    def validate_parameter_conversion_default_params(sim_copy, sim_orig):
        """Compare both simulations at their configured parameter values."""
        ParameterConverter._assert_equal_results(
            sim_copy, sim_orig, lambda s: s.config.model_parameters.value_dict
        )

    @staticmethod
    def validate_parameter_conversion_posterior_mean(sim_copy, sim_orig):
        """Compare both simulations at their posterior mean estimates."""
        ParameterConverter._assert_equal_results(
            sim_copy, sim_orig, lambda s: s.point_estimate("mean", to="dict")
        )

    @staticmethod
    def validate_parameter_conversion_posterior_map(sim_copy, sim_orig):
        """Compare both simulations at their posterior MAP estimates."""
        ParameterConverter._assert_equal_results(
            sim_copy, sim_orig, lambda s: s.point_estimate("map", to="dict")
        )