guts-base 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of guts-base might be problematic.

guts_base/sim/ecx.py ADDED
@@ -0,0 +1,357 @@
+ import warnings
+ from functools import partial
+ import numpy as np
+ import xarray as xr
+ from typing import Literal, Optional, Dict, List
+ import pandas as pd
+ from scipy.optimize import minimize
+ from matplotlib import pyplot as plt
+ from tqdm import tqdm
+
+ from pymob import SimulationBase
+
+ class ECxEstimator:
+     """Estimates the exposure level that corresponds to a given effect. The algorithm
+     operates by scaling a given exposure profile (x_in) with a multiplication factor
+     until the effect variable at the given time is reduced to (1 - x).
+     """
+     _name = "EC"
+
+     def __init__(
+         self,
+         sim: SimulationBase,
+         effect: str,
+         x: float=0.5,
+         id: Optional[str]=None,
+         time: Optional[float]=None,
+         x_in: Optional[xr.Dataset]=None,
+     ):
+         self.sim = sim.copy()
+         self.time = time
+         self.x = x
+         self.id = id
+         self.effect = effect
+         self._mode = None
+
+         if id is None:
+             self.sim.coordinates["id"] = [self.sim.coordinates["id"][0]]
+         else:
+             self.sim.coordinates["id"] = [id]
+
+         self.sim.model_parameters["x_in"] = x_in
+
+         # self.sim.observations = self.sim.expand_batch_like_coordinate_to_new_dimension(
+         #     coordinate="exposure_path",
+         #     variables=["Flupyradifurone"]
+         # )
+
+         # self.sim.config.data_structure.remove("Flupyradifurone")
+
+         # TODO: Construct a sim if the input dims change
+         # self.sim.config.data_structure.exposure.dimensions = ["id", "time", "exposure_path"]
+         self.sim.config.data_structure.survival.observed = False
+         self.sim.observations = self.sim.observations.sel(id=self.sim.coordinates["id"])
+
+         self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims="time")
+         self.sim.dispatch_constructor()
+
+     def _evaluate(self, factor, theta):
+         evaluator = self.sim.dispatch(
+             theta=theta,
+             x_in=self.sim.validate_model_input(self.sim.model_parameters["x_in"] * factor)
+         )
+         evaluator()
+         return evaluator
+
+     def _loss(self, log_factor, theta):
+         # exponentiate the log factor
+         factor = np.exp(log_factor)
+
+         e = self._evaluate(factor, theta)
+         s = e.results.sel(time=self.time)[self.effect].values
+
+         # squared distance between the predicted effect variable and the target level (1 - x)
+         return (s - (1 - self.x)) ** 2
+
+     def _posterior_mean(self):
+         mean = self.sim.inferer.idata.posterior.mean(("chain", "draw"))
+         mean = {k: v["data"] for k, v in mean.to_dict()["data_vars"].items()}
+         return mean
+
+     def _posterior_sample(self, i):
+         posterior_stacked = self.sim.inferer.idata.posterior.stack(
+             sample=("chain", "draw")
+         )
+         sample = posterior_stacked.isel(sample=i)
+         sample = {k: v["data"] for k, v in sample.to_dict()["data_vars"].items()}
+         return sample
+
+     def plot_loss_curve(self):
+         posterior_mean = self._posterior_mean()
+
+         # log-scale factors; exponentiated for plotting on the linear scale
+         log_factor = np.linspace(-2, 2, 100)
+         y = list(map(partial(self._loss, theta=posterior_mean), log_factor))
+
+         fig, ax = plt.subplots(1, 1, sharey=True, figsize=(4, 3))
+         ax.plot(
+             np.exp(log_factor), y,
+             color="black",
+             label=rf"$\ell = S(t={self.time},x_{{in}}=C_{{ext}} \cdot \phi) - {self.x}$"
+         )
+         ax.set_ylabel(r"Loss ($\ell$)")
+         ax.set_xlabel(r"Multiplication factor ($\phi$)")
+         ax.set_title(f"ID: {self.sim.coordinates['id'][0]}")
+         ax.set_ylim(0, np.max(y) * 1.25)
+         ax.legend(frameon=False)
+         fig.tight_layout()
+
+     def estimate(
+         self,
+         mode: Literal["draws", "mean", "manual"] = "draws",
+         draws: Optional[int] = None,
+         parameters: Optional[Dict[str, float|List[float]]] = None,
+         log_x0: float = 0.0,
+         accept_tol: float = 1e-5,
+         optimizer_tol: float = 1e-5,
+         method: str = "cobyla",
+         **optimizer_kwargs
+     ):
+         """The minimizer for the EC_x operates on the unbounded log scale, estimating
+         the log-modification factor. Converting back to the linear scale with
+         factor = exp(log_factor) yields the profile modification factor.
+
+         Using log_x0=0.0 means optimization starts at the unmodified exposure profile
+         on the linear scale. Optimizing on the log scale gives much smoother
+         optimization performance, because multiplicative steps on the log scale
+         require much less adaptation.
+
+         Parameters
+         ----------
+
+         mode : Literal['draws', 'mean', 'manual']
+             Mode of estimation. mode='mean' takes the mean of the posterior and estimates
+             the ECx for this single value. mode='draws' takes samples from the posterior
+             and estimates the ECx for each of the parameter draws. mode='manual' takes
+             a parameter set (Dict) in the parameters argument and uses that for estimation.
+             Default: 'draws'
+
+         draws : int
+             Number of draws to take from the posterior. Only takes effect if mode='draws'.
+             Raises an exception if draws < 100, because this is insufficient for a
+             reasonable uncertainty estimate. Default: None (use all samples from the
+             posterior)
+
+         parameters : Dict[str, float|List[float]]
+             A parameter dictionary used as model parameters for finding the ECx value.
+             Only used if mode='manual'. Default: None
+
+         log_x0 : float
+             The starting value of the multiplication factor of the exposure profile for
+             the minimization algorithm, given on the log scale. Since
+             exp(log_x0=0.0) = 1.0, the default log_x0=0.0 starts at an unmodified
+             exposure profile. Default: 0.0
+
+         accept_tol : float
+             After optimization has finished, accept_tol is used to assess whether the
+             loss of an individual draw exceeds the tolerance. Such results are discarded
+             and a warning is emitted. This asserts that no faulty optimization results
+             enter the estimate. Default: 1e-5
+
+         optimizer_tol : float
+             Tolerance limit for the minimizer to stop optimization. Default: 1e-5
+
+         method : str
+             Minimization algorithm. See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
+             Default: 'cobyla'
+
+         optimizer_kwargs :
+             Additional arguments passed to the optimizer
+         """
+         x0_tries = np.array([0.0, -1.0, 1.0, -2.0, 2.0]) + log_x0
+
+         if mode == "draws":
+             if draws is None:
+                 draws = (
+                     self.sim.inferer.idata.posterior.sizes["chain"] *
+                     self.sim.inferer.idata.posterior.sizes["draw"]
+                 )
+             elif draws < 100:
+                 raise ValueError(
+                     "draws must be larger than 100. Preferably > 1000. "
+                     f"If you don't want uncertainty assessment of the {self._name} "
+                     "estimates, use mode='mean'"
+                 )
+             else:
+                 pass
+
+         elif mode == "mean":
+             draws = 1
+         elif mode == "manual":
+             draws = 1
+             if parameters is None:
+                 raise ValueError(
+                     "parameters need to be provided if mode='manual'"
+                 )
+         else:
+             raise NotImplementedError(
+                 f"Bad mode: {mode}. Mode must be one of 'mean', 'draws' or 'manual'"
+             )
+
+         self._mode = mode
+         mult_factor = []
+         loss = []
+         iterations = []
+         for i in tqdm(range(draws)):
+             if mode == "draws":
+                 sample = self._posterior_sample(i)
+             elif mode == "mean":
+                 sample = self._posterior_mean()
+             elif mode == "manual":
+                 sample = parameters
+             else:
+                 raise NotImplementedError(
+                     f"Bad mode: {mode}. Mode must be one of 'mean', 'draws' or 'manual'"
+                 )
+
+             success = False
+             iteration = 0
+             while not success and iteration < len(x0_tries):
+                 opt_res = minimize(
+                     self._loss, x0=x0_tries[iteration],
+                     method=method,
+                     tol=optimizer_tol,
+                     args=(sample,),
+                     **optimizer_kwargs
+                 )
+
+                 success = opt_res.fun < accept_tol
+                 if not success:
+                     # retry from the next starting value
+                     iteration += 1
+
+             # convert to linear scale from log scale
+             factor = np.exp(opt_res.x)
+
+             mult_factor.extend(factor)
+             iterations.append(iteration)
+             loss.append(opt_res.fun)
+
+         res_full = pd.DataFrame(dict(factor=mult_factor, loss=loss, retries=iterations))
+         if sum(res_full.loss >= accept_tol) > 0:
+             warnings.warn(
+                 f"Not all optimizations converged on the {self._name}_{self.x}. " +
+                 "Adjust starting values and method")
+             print(res_full)
+
+         res = res_full.loc[res_full.loss < accept_tol, :]
+
+         summary = {
+             "mean": np.round(np.mean(res.factor.values), 4),
+             "q05": np.round(np.quantile(res.factor.values, 0.05), 4),
+             "q95": np.round(np.quantile(res.factor.values, 0.95), 4),
+             "std": np.round(np.std(res.factor.values), 4),
+             "cv": np.round(np.std(res.factor.values) / np.mean(res.factor.values), 2),
+         }
+
+         self.results = pd.Series(summary)
+         self.results_full = res_full
+
+         print("{name}_{x}".format(name=self._name, x=int(self.x * 100)))
+         print(self.results)
+         print("\n")
+
+     def plot_profile_and_effect(
+         self,
+         parameters: Optional[Dict[str, float|List[float]]] = None
+     ):
+         coordinates_backup = self.sim.coordinates["time"].copy()
+
+         self.sim.coordinates["time"] = np.linspace(0, self.time, 100)
+         self.sim.dispatch_constructor()
+
+         if self._mode is None:
+             raise RuntimeError(
+                 "Run .estimate() before plot_profile_and_effect()"
+             )
+         elif self._mode == "mean" or self._mode == "draws":
+             e_new = self._evaluate(factor=self.results["mean"], theta=self._posterior_mean())
+             e_old = self._evaluate(factor=1.0, theta=self._posterior_mean())
+         elif self._mode == "manual":
+             if parameters is None:
+                 raise RuntimeError(
+                     f"If {self._name}_x was estimated using manual mode, parameters must " +
+                     "also be provided here."
+                 )
+             e_new = self._evaluate(factor=self.results["mean"], theta=parameters)
+             e_old = self._evaluate(factor=1.0, theta=parameters)
+
+         extra_dim = [k for k in list(e_old.results.coords.keys()) if k not in ["time", "id"]]
+
+         if len(extra_dim) > 0:
+             labels_old = [
+                 f"{l} (original)" for l
+                 in e_old.results.coords[extra_dim[0]].values
+             ]
+             labels_new = [
+                 f"{l} (modified)" for l
+                 in e_new.results.coords[extra_dim[0]].values
+             ]
+         else:
+             labels_old = "original"
+             labels_new = "modified"
+
+         fig, (ax1, ax2) = plt.subplots(2, 1, height_ratios=[1, 3], sharex=True)
+         ax1.plot(
+             e_old.results.time, e_old.results.exposure.isel(id=0),
+             ls="--", label=labels_old,
+         )
+         ax1.set_prop_cycle(None)
+         ax1.plot(
+             e_new.results.time, e_new.results.exposure.isel(id=0),
+             label=labels_new
+         )
+
+         ax2.plot(
+             e_new.results.time, e_new.results.survival.isel(id=0),
+             color="black", ls="--", label="modified"
+         )
+         ax1.set_prop_cycle(None)
+
+         ax2.plot(
+             e_old.results.time, e_old.results.survival.isel(id=0),
+             color="black", ls="-", label="original"
+         )
+         ax2.hlines(self.x, e_new.results.time[0], self.time, color="grey")
+         ax1.set_ylabel("Exposure")
+         ax2.set_ylabel("Survival")
+         ax2.set_xlabel("Time")
+         ax1.legend()
+         ax2.legend()
+         ax2.set_xlim(0, None)
+         ax1.set_ylim(0, None)
+         ax2.set_ylim(0, None)
+         fig.tight_layout()
+
+         self.sim.coordinates["time"] = coordinates_backup
+         self.sim.dispatch_constructor()
+
+
+ class LPxEstimator(ECxEstimator):
+     """Computes the LPx: using the existing exposure profile of the specified ID, it
+     estimates the multiplication factor of that profile which results in an effect of
+     x * 100 %.
+     """
+     _name = "LP"
+
+     def __init__(
+         self,
+         sim: SimulationBase,
+         id: str,
+         x: float=0.5
+     ):
+         x_in = sim.model_parameters["x_in"].sel(id=[id])
+         time = sim.coordinates["time"][-1]
+         super().__init__(sim=sim, effect="survival", x=x, id=id, time=time, x_in=x_in)
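
Usage sketch (not part of the package): the snippet below assumes a calibrated guts_base simulation object `sim` whose `sim.inferer.idata` already holds a posterior; the treatment id "T1" is made up.

    from guts_base.sim.ecx import LPxEstimator

    # LP50: factor by which the exposure profile of id "T1" must be scaled
    # so that survival at the end of the profile drops to 0.5
    lpx = LPxEstimator(sim=sim, id="T1", x=0.5)
    lpx.plot_loss_curve()                  # visual check of the loss over the factor
    lpx.estimate(mode="draws", draws=250)  # one optimization per posterior draw
    print(lpx.results)                     # mean, q05, q95, std, cv of the factor
    lpx.plot_profile_and_effect()          # original vs. modified profile and survival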
guts_base/sim/mempy.py ADDED
@@ -0,0 +1,252 @@
+ import pathlib
+ from typing import Dict, Optional, Literal
+ import re
+
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+ from pymob import SimulationBase
+ from pymob.sim.config import Config, DataVariable, Datastructure
+ from pymob.sim.parameters import Param
+ from guts_base.sim import GutsBase
+ from mempy.model import (
+     Model,
+     RED_IT,
+     RED_SD,
+     RED_IT_DA,
+     RED_SD_DA,
+     RED_IT_IA,
+     RED_SD_IA,
+     BufferGUTS_IT,
+     BufferGUTS_IT_CA,
+     BufferGUTS_IT_DA
+ )
+
+ __all__ = [
+     "PymobSimulator",
+ ]
+
+ class PymobSimulator(GutsBase):
+
+     @classmethod
+     def from_mempy(
+         cls,
+         exposure_data: Dict,
+         survival_data: Dict,
+         model: Model,
+         info_dict: Dict = {},
+         pymob_config: Optional[Config] = None,
+         output_directory: str|pathlib.Path = pathlib.Path("output/pymob"),
+         default_prior: Literal["uniform", "lognorm"] = "lognorm",
+     ) -> SimulationBase:
+         """Construct a PymobSimulator from the exposure and survival data and the
+         model instance of a mempy session.
+         """
+
+         if pymob_config is None:
+             cfg = Config()
+             # default configuration; can be overridden in a subclass
+             cls.configure(config=cfg)
+         else:
+             cfg = pymob_config
+
+         if isinstance(output_directory, str):
+             output_directory = pathlib.Path(output_directory)
+
+         cfg.case_study.output = str(output_directory)
+
+         # parse observations
+         # obs can simply be subset by selection, e.g. obs.sel(substance="Exposure-Dime")
+         observations = xr.combine_by_coords([
+             cls._exposure_data_to_xarray(exposure_data, dim=model.extra_dim),
+             cls._survival_data_to_xarray(survival_data)
+         ])
+
+         # configure model and likelihood function
+         cfg.simulation.model = type(model).__name__
+         cfg.inference_numpyro.user_defined_error_model = str(model._likelihood_func_jax.__name__)
+
+         # derive data structure and params from the model instance
+         cls._set_data_structure(config=cfg, model=model)
+         cls._set_params(config=cfg, model=model, default_prior=default_prior)
+
+         # configure starting values and input
+         cfg.simulation.x_in = ["exposure=exposure"]
+         cfg.simulation.y0 = [f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v]
+
+         # create a simulation object
+         sim = cls(config=cfg)
+         sim.config.create_directory(directory="results", force=True)
+
+         # initialize
+         sim.load_modules()
+         sim.set_logger()
+
+         sim.initialize(input={"observations": observations, "model": model})
+
+         sim.validate()
+         sim.dispatch_constructor()
+
+         return sim
+
+     def initialize(self, input=None):
+         self.model = input["model"]._rhs_jax
+         self.solver_post_processing = input["model"]._solver_post_processing
+
+         super().initialize(input=input)
+
+     @classmethod
+     def configure(cls, config: Config):
+         """This is normally set in the configuration file passed to a SimulationBase
+         class. Since the mempy-to-pymob converter initializes pymob.SimulationBase from
+         scratch (without using a config file), the necessary settings have to be
+         specified here.
+         """
+         config.case_study.output = "results"
+
+         config.simulation.x_dimension = "time"
+         config.simulation.batch_dimension = "id"
+         config.simulation.solver_post_processing = None
+         config.simulation.unit_time = "day"
+         config.simulation.n_reindexed_x = 100
+         config.simulation.forward_interpolate_exposure_data = True
+
+         config.inference.extra_vars = ["eps", "survivors_before_t"]
+         config.inference.n_predictions = 100
+
+         config.jaxsolver.diffrax_solver = "Tsit5"
+         config.jaxsolver.rtol = 1e-10
+         config.jaxsolver.atol = 1e-12
+         config.jaxsolver.pcoeff = 0.3
+         config.jaxsolver.icoeff = 0.3
+         config.jaxsolver.dcoeff = 0.0
+         config.jaxsolver.max_steps = 1000000
+         config.jaxsolver.throw_exception = True
+
+         config.inference_numpyro.gaussian_base_distribution = True
+         config.inference_numpyro.kernel = "svi"
+         config.inference_numpyro.init_strategy = "init_to_median"
+         config.inference_numpyro.svi_iterations = 10_000
+         config.inference_numpyro.svi_learning_rate = 0.001
+
+     @staticmethod
+     def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
+         """
+         TODO: Currently no rect interpolation
+         """
+         arrays = {}
+         for key, df in exposure_data.items():
+             # this override is necessary to make all dimensions work out
+             df.index.name = "time"
+             arrays.update({
+                 key: df.to_xarray().to_dataarray(dim="id", name=key)
+             })
+
+         exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
+         exposure_array = exposure_array.transpose("id", "time", ...)
+         return xr.Dataset({"exposure": exposure_array})
+
+     @staticmethod
+     def _survival_data_to_xarray(survival_data: pd.DataFrame):
+         # TODO: the survival variable name is currently not kept, because the raw data
+         # is not transferred from the survival DataFrame
+         survival_data.index.name = "time"
+
+         survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
+         survival_array = survival_array.transpose("id", "time", ...)
+         arrays = {"survival": survival_array}
+         return xr.Dataset(arrays)
+
+     @classmethod
+     def _set_data_structure(cls, config: Config, model: Model):
+         """Takes the state-variable dictionary specified in the model and uses only the
+         keys that are fields of the DataVariable config-model."""
+
+         state_dict = model.state_variables
+
+         config.data_structure = Datastructure(**{
+             key: DataVariable(**{
+                 k: v for k, v in state_info.items()
+                 if k in DataVariable.model_fields
+             })
+             for key, state_info in state_dict.items()
+         })
+
+     @classmethod
+     def _set_params(cls, config: Config, model: Model, default_prior: str):
+         params_info = model.params_info
+
+         if isinstance(model, (
+             RED_IT, RED_IT_DA, RED_IT_IA,
+             BufferGUTS_IT, BufferGUTS_IT_CA, BufferGUTS_IT_DA
+         )):
+             # IT-type models get a fixed eps parameter derived from the solver tolerance
+             eps = config.jaxsolver.atol * 10
+             params_info["eps"] = {'name': 'eps', 'initial': eps, 'vary': False}
+
+         # fill missing keys with the model's defaults
+         for par, param_dict in params_info.items():
+             for k, v in model._params_info_defaults.items():
+                 if k not in param_dict:
+                     param_dict.update({k: v})
+
+         # group parameters that only differ by a numeric suffix into one
+         # (possibly multi-dimensional) parameter
+         param_df = pd.DataFrame(params_info).T
+         param_df["param_index"] = param_df.name.apply(lambda x: re.findall(r"\d+", x))
+         param_df["param_index"] = param_df.param_index.apply(lambda x: int(x[0])-1 if len(x) == 1 else None)
+         param_df["name"] = param_df.name.apply(lambda x: re.sub(r"\d+", "", x).strip("_"))
+
+         for (param_name, ), group in param_df.groupby(["name"]):
+
+             dims = list(dict.fromkeys(group["dims"]))
+             dims = tuple([]) if dims == [None] else tuple(dims)
+
+             prior = list(dict.fromkeys(group["prior"]))
+             prior = prior[0] if len(prior) == 1 else prior
+
+             _min = np.min(np.ma.masked_invalid(group["min"].values.astype(float)))
+             _max = np.max(np.ma.masked_invalid(group["max"].values.astype(float)))
+             _init = group["initial"].values.astype(float)
+             _free = group["vary"].values
+
+             # TODO: allow for parsing one N-D prior from multiple priors
+             # TODO: Another choice would be to parse vary=False priors as deterministic
+             #   and use a composite prior from a deterministic and a free prior as
+             #   the input into the model
+
+             if prior is None:
+                 if default_prior == "uniform":
+                     _loc = _init * np.logical_not(_free) + _min * _free - config.jaxsolver.atol * 10 * np.logical_not(_free)
+                     _scale = _init * np.logical_not(_free) + _max * _free + config.jaxsolver.atol * 10 * np.logical_not(_free)
+                     _loc = _loc[0] if len(_loc) == 1 else _loc
+                     _scale = _scale[0] if len(_scale) == 1 else _scale
+                     prior = f"uniform(loc={_loc},scale={_scale})"
+                 elif default_prior == "lognorm":
+                     _s = 3 * _free + config.jaxsolver.atol * 10 * np.logical_not(_free)
+                     _init = _init[0] if len(_init) == 1 else _init
+                     _s = _s[0] if len(_s) == 1 else _s
+
+                     prior = f"lognorm(scale={_init},s={_s})"
+                 else:
+                     raise ValueError(
+                         f"Default prior: '{default_prior}' is not implemented. " +
+                         "Use one of 'uniform', 'lognorm' or specify priors for each " +
+                         "parameter directly with: " +
+                         f"`model.params_dict['prior'] = {default_prior}(...)`"
+                     )
+
+             prior = prior.replace(" ", ",")
+
+             # if isinstance(value, float):
+             param = Param(
+                 value=_init,
+                 free=np.max(_free),
+                 min=_min,
+                 max=_max,
+                 prior=prior,
+                 dims=dims
+             )
+
+             setattr(config.model_parameters, param_name, param)
+
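
Conversion sketch (illustrative, not part of the package): the frame shapes below follow what _exposure_data_to_xarray and _survival_data_to_xarray expect (a time index and one column per replicate id); the RED_SD() constructor call, the substance key and all data values are assumptions.

    import pandas as pd
    from mempy.model import RED_SD
    from guts_base.sim.mempy import PymobSimulator

    # one DataFrame per exposure series, indexed by time, one column per replicate id
    exposure_data = {
        "substance_A": pd.DataFrame({"r0": [5.0, 5.0, 0.0, 0.0]}, index=[0.0, 1.0, 2.0, 3.0])
    }
    # survival counts over time for the same replicate ids
    survival_data = pd.DataFrame({"r0": [10, 9, 7, 6]}, index=[0.0, 1.0, 2.0, 3.0])

    sim = PymobSimulator.from_mempy(
        exposure_data=exposure_data,
        survival_data=survival_data,
        model=RED_SD(),                  # assumed constructor signature
        output_directory="output/pymob",
        default_prior="lognorm",
    )
    # sim is a fully configured pymob SimulationBase, ready for inference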
@@ -0,0 +1,72 @@
+ import os
+ import itertools as it
+ import pandas as pd
+
+ from pymob import SimulationBase
+ from pymob.sim.report import Report, reporting
+
+ from guts_base.plot import plot_survival_multipanel
+ from guts_base.sim.ecx import ECxEstimator
+
+ class GutsReport(Report):
+
+     def additional_reports(self, sim: "SimulationBase"):
+         super().additional_reports(sim=sim)
+         self.model_fits(sim)
+         self.LCx_estimates(sim)
+
+     @reporting
+     def model_fits(self, sim: SimulationBase):
+         self._write("### Survival model fits")
+
+         out_mp = plot_survival_multipanel(
+             sim=sim,
+             results=sim.inferer.idata.posterior_model_fits,
+             ncols=6,
+         )
+
+         lab = self._label.format(placeholder='survival_fits')
+         self._write(rf"![Survival model fits.\label{{{lab}}}]({os.path.basename(out_mp)})")
+
+         return out_mp
+
+     @reporting
+     def LCx_estimates(self, sim):
+         X = [0.1, 0.25, 0.5, 0.75, 0.9]
+         T = [1, 2]
+         P = sim.predefined_scenarios
+
+         estimates = pd.DataFrame(
+             it.product(X, T, P.keys()),
+             columns=["x", "time", "scenario"]
+         )
+
+         ecx = []
+
+         for i, row in estimates.iterrows():
+             ecx_estimator = ECxEstimator(
+                 sim=sim,
+                 effect="survival",
+                 x=row.x,
+                 id=None,
+                 time=row.time,
+                 x_in=P[row.scenario],
+             )
+
+             ecx_estimator.estimate(
+                 mode=sim.ecx_mode,
+                 draws=250,
+             )
+
+             ecx.append(ecx_estimator.results)
+
+         results = pd.DataFrame(ecx)
+         estimates[results.columns] = results
+
+         file = os.path.join(sim.output_path, "lcx_estimates.csv")
+         estimates.to_csv(file)
+         lab = self._label.format(placeholder='$LC_x$ estimates')
+         self._write_table(tab=estimates, label_insert=rf"$LC_x$ estimates \label{{{lab}}}]({os.path.basename(file)})")
+
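
For orientation (illustrative snippet, not part of the package): the table scaffold built in LCx_estimates is simply the Cartesian product of effect levels, time points and exposure scenarios; the scenario names here are made up.

    import itertools as it
    import pandas as pd

    X = [0.1, 0.25, 0.5, 0.75, 0.9]      # effect levels (10 % ... 90 %)
    T = [1, 2]                           # evaluation times
    scenarios = ["constant", "pulsed"]   # hypothetical scenario names

    estimates = pd.DataFrame(it.product(X, T, scenarios), columns=["x", "time", "scenario"])
    # one row per combination; ECxEstimator.estimate() later contributes mean, q05, q95, std, cv
    print(estimates.head())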
guts_base/sim.py ADDED
File without changes