guts-base 0.8.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. See the package's registry page for more details.

guts_base/sim/ecx.py CHANGED
@@ -9,45 +9,91 @@ from matplotlib import pyplot as plt
9
9
  from tqdm import tqdm
10
10
 
11
11
  from pymob import SimulationBase
12
+ from guts_base.sim.utils import GutsBaseError
12
13
 
13
14
  class ECxEstimator:
14
15
  """Estimates the exposure level that corresponds to a given effect. The algorithm
15
- operates by varying a given exposure profile (x_in)
16
+ operates by varying a given exposure profile (x_in). For each new estimation, a new
17
+ estimator is initialized.
18
+
19
+ Parameters
20
+ ----------
21
+
22
+ sim : SimulationBase
23
+ This must be a pymob.SimulationBase object. If the ECxEstimator.estimate method
24
+ is used with the modes 'draw' or 'mean'
25
+
16
26
  """
17
27
  _name = "EC"
28
+ _parameter_msg = (
29
+ "Manual estimation (mode='manual', without using posterior information) requires " +
30
+ "specification of parameters={...}. You can obtain and modify " +
31
+ "parameters using the pymob API: `sim.config.model_parameters.value_dict` " +
32
+ "returns a dictionary of DEFAULT PARAMETERS that you can customize to your liking " +
33
+ "(https://pymob.readthedocs.io/en/stable/api/pymob.sim.html#pymob.sim.config.Modelparameters.value_dict)."
34
+ )
18
35
 
19
36
  def __init__(
20
37
  self,
21
38
  sim: SimulationBase,
22
39
  effect: str,
23
- x: float=0.5,
24
- id: Optional[str]=None,
25
- time: Optional[float]=None,
26
- x_in: Optional[xr.Dataset]=None,
40
+ x: float,
41
+ time: float,
42
+ x_in: xr.Dataset,
27
43
  ):
28
44
  self.sim = sim.copy()
29
45
  self.time = time
30
46
  self.x = x
31
- self.id = id
32
47
  self.effect = effect
33
48
  self._mode = None
34
49
 
35
- if id is None:
36
- self.sim.coordinates["id"] = [self.sim.coordinates["id"][0]]
37
- else:
38
- self.sim.coordinates["id"] = [id]
50
+ # creates an empty observation dataset with the coordinates of the
51
+ # original observations (especially time), except the ID, which is overwritten
52
+ # and taken from the x_in dataset
53
+ pseudo_obs = self.sim.observations.isel(id=[0])
54
+ pseudo_obs = pseudo_obs.drop(["exposure","survival"])
55
+ pseudo_obs["id"] = x_in["id"]
56
+
57
+ self.sim.config.data_structure.survival.observed = False
58
+ self.sim.observations = pseudo_obs
39
59
 
40
- self.sim.model_parameters["x_in"] = x_in
60
+ # ensure correct coordinate order with x_in and raise errors early
61
+ self.sim.model_parameters["x_in"] = self.sim.parse_input("x_in", x_in)
41
62
 
42
- self.sim.config.data_structure.survival.observed = False
43
- self.sim.observations = self.sim.observations.sel(id=self.sim.coordinates["id"])
44
63
 
45
- # fix time after observations have been set
46
- self.sim.coordinates["time"] = [time]
64
+ # fix time after observations have been set. The outcome of the simulation
65
+ # can dependend on the time vector, because in e.g. IT models, the time resolution
66
+ # is important. Therefore the time at which the ECx is computed is just inserted
67
+ # into the time vector at the right position.
68
+ self.sim.coordinates["time"] = np.unique(np.concatenate([
69
+ self.sim.coordinates["time"], np.array(time, ndmin=1)
70
+ ]))
47
71
 
48
72
  self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims="time")
49
73
  self.sim.dispatch_constructor()
50
74
 
75
+ self.results = pd.Series({
76
+ "mean": np.nan,
77
+ "q05": np.nan,
78
+ "q95": np.nan,
79
+ "std": np.nan,
80
+ "cv": np.nan,
81
+ "msg": np.nan
82
+ })
83
+
84
+ self.figure_profile_and_effect = None
85
+ self.figure_loss_curve = None
86
+
87
+
88
+ def _assert_posterior(self):
89
+ try:
90
+ p = self.sim.inferer.idata.posterior
91
+ except AttributeError:
92
+ raise GutsBaseError(
93
+ "Using mode='mode' or mode='draws', but sim did not contain a posterior. " +
94
+ "('sim.inferer.idata.posterior'). " + self._parameter_msg
95
+ )
96
+
51
97
 
52
98
 
53
99
  def _evaluate(self, factor, theta):
@@ -80,25 +126,128 @@ class ECxEstimator:
80
126
  sample = {k: v["data"] for k, v in sample.to_dict()["data_vars"].items()}
81
127
  return sample
82
128
 
83
- def plot_loss_curve(self):
84
- posterior_mean = self._posterior_mean()
129
+ def plot_loss_curve(self,
130
+ mode: Literal["draws", "mean", "manual"] = "draws",
131
+ draws: Optional[int] = None,
132
+ parameters: Optional[Dict[str,float|List[float]]] = None,
133
+ log_x0: float = 0.0,
134
+ force_draws: bool = False
135
+ ):
136
+ """
137
+ Parameters
138
+ ----------
85
139
 
86
- factor = np.linspace(-2,2, 100)
87
- y = list(map(partial(self._loss, theta=posterior_mean), factor))
140
+ mode : Literal['draws', 'mean', 'manual']
141
+ mode of estimation. mode='mean' takes the mean of the posterior and estimate
142
+ the ECx for this singular value. mode='draws' takes samples from the posterior
143
+ and estimate the ECx for each of the parameter draws. mode='manual' takes
144
+ a parameter set (Dict) in the parameters argument and uses that for estimation.
145
+ Default: 'draws'
146
+
147
+ draws : int
148
+ Number of draws to take from the posterior. Only takes effect if mode='draw'.
149
+ Raises an exception if draws < 100, because this is insufficient for a
150
+ reasonable uncertainty estimate. Default: None (using all samples from the
151
+ posterior)
152
+
153
+ parameters : Dict[str,float|list[float]]
154
+ a parameter dictionary passed used as model parameters for finding the ECx
155
+ value. Default: None
88
156
 
157
+
158
+ log_x0 : float
159
+ the starting value for the multiplication factor of the exposure profile for
160
+ the minimization algorithm. This value is on the log scale. This means,
161
+ exp(log_x0=0.0) = 1.0, which means that the log_x0=0.0 will start at an
162
+ unmodified exposure profile. Default: 0.0
163
+
164
+ force_draws : bool
165
+ Force the estimate method to accept a number of draws less than 100. Default: False
166
+
167
+ """
168
+ draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)
169
+
170
+
171
+ factor = np.linspace(-2,2, 100) + log_x0
89
172
  fig, ax = plt.subplots(1,1, sharey=True, figsize=(4, 3))
173
+
174
+ for i in tqdm(range(draws)):
175
+ if mode == "draws":
176
+ sample = self._posterior_sample(i)
177
+ elif mode == "mean":
178
+ sample = self._posterior_mean()
179
+ elif mode == "manual":
180
+ sample = parameters
181
+ else:
182
+ raise NotImplementedError(
183
+ f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
184
+ )
185
+
186
+ y = list(map(partial(self._loss, theta=sample), factor))
187
+
188
+ ax.plot(
189
+ np.exp(factor), y,
190
+ color="black",
191
+ )
192
+
90
193
  ax.plot(
91
- np.exp(factor), y,
92
- color="black",
194
+ [], [], color="black",
93
195
  label=f"$\ell = S(t={self.time},x_{{in}}=C_{{ext}} \cdot \phi) - {self.x}$"
94
196
  )
95
197
  ax.set_ylabel("Loss ($\ell$)")
96
198
  ax.set_xlabel("Multiplication factor ($\phi$)")
97
199
  ax.set_title(f"ID: {self.sim.coordinates['id'][0]}")
98
- ax.set_ylim(0, np.max(y) * 1.25)
200
+ ax.set_ylim(0, ax.get_ylim()[1] * 1.25)
99
201
  ax.legend(frameon=False)
100
202
  fig.tight_layout()
101
203
 
204
+ self.figure_loss_curve = fig
205
+
206
+ def _check_mode_and_draws_and_parameters(self, mode, draws, parameters, force_draws):
207
+
208
+ if mode == "draws":
209
+ self._assert_posterior()
210
+
211
+ if draws is None:
212
+ draws = (
213
+ self.sim.inferer.idata.posterior.sizes["chain"] *
214
+ self.sim.inferer.idata.posterior.sizes["draw"]
215
+ )
216
+ elif draws < 100 and not force_draws:
217
+ raise GutsBaseError(
218
+ "draws must be larger than 100. Preferably > 1000. " +
219
+ f"If you don't want uncertainty assessment of the {self._name} " +
220
+ "estimates, use mode='mean'. If you really want to use less than " +
221
+ "100 draws, use force_draws=True at your own risk."
222
+ )
223
+ else:
224
+ pass
225
+
226
+ warnings.warn(
227
+ "Values passed to 'parameters' don't have an effect in mode='draws'"
228
+ )
229
+
230
+ elif mode == "mean":
231
+ self._assert_posterior()
232
+
233
+ draws = 1
234
+
235
+ warnings.warn(
236
+ "Values passed to 'parameters' don't have an effect in mode='draws'"
237
+ )
238
+
239
+ elif mode == "manual":
240
+ draws = 1
241
+ if parameters is None:
242
+ raise GutsBaseError(self._parameter_msg)
243
+ else:
244
+ raise GutsBaseError(
245
+ f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
246
+ )
247
+
248
+ return draws
249
+
250
+
102
251
  def estimate(
103
252
  self,
104
253
  mode: Literal["draws", "mean", "manual"] = "draws",
@@ -108,6 +257,8 @@ class ECxEstimator:
108
257
  accept_tol: float = 1e-5,
109
258
  optimizer_tol: float = 1e-5,
110
259
  method: str = "cobyla",
260
+ show_plot: bool = True,
261
+ force_draws: bool = False,
111
262
  **optimizer_kwargs
112
263
  ):
113
264
  """The minimizer for the EC_x operates on the unbounded linear scale, estimating
@@ -157,40 +308,19 @@ class ECxEstimator:
157
308
  method : str
158
309
  Minization algorithm. See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
159
310
  Default: 'cobyla'
311
+
312
+ show_plot : bool
313
+ Show the results plot of the lpx. Default: True
160
314
 
315
+ force_draws : bool
316
+ Force the estimate method to accept a number of draws less than 100. Default: False
317
+
161
318
  optimizer_kwargs :
162
319
  Additional arguments to pass to the optimizer
163
320
 
164
321
  """
165
322
  x0_tries = np.array([0.0, -1.0, 1.0, -2.0, 2.0]) + log_x0
166
-
167
- if mode == "draws":
168
- if draws is None:
169
- draws = (
170
- self.sim.inferer.idata.posterior.sizes["chain"] *
171
- self.sim.inferer.idata.posterior.sizes["draw"]
172
- )
173
- elif draws < 100:
174
- raise ValueError(
175
- "draws must be larger than 100. Preferably > 1000. "
176
- f"If you don't want uncertainty assessment of the {self._name} "
177
- "estimates, use mode='mean'"
178
- )
179
- else:
180
- pass
181
-
182
- elif mode == "mean":
183
- draws = 1
184
- elif mode == "manual":
185
- draws = 1
186
- if parameters is None:
187
- raise ValueError(
188
- "parameters need to be provided if mode='manual'"
189
- )
190
- else:
191
- raise NotImplementedError(
192
- f"Bad mode: {mode}. Mode must be one 'mean' or 'draws'"
193
- )
323
+ draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)
194
324
 
195
325
  self._mode = mode
196
326
  mult_factor = []
@@ -220,6 +350,7 @@ class ECxEstimator:
220
350
  )
221
351
 
222
352
  success = opt_res.fun < accept_tol
353
+ iteration += 1
223
354
 
224
355
  # convert to linear scale from log scale
225
356
  factor = np.exp(opt_res.x)
@@ -229,24 +360,43 @@ class ECxEstimator:
229
360
  loss.append(opt_res.fun)
230
361
 
231
362
  res_full = pd.DataFrame(dict(factor = mult_factor, loss=loss, retries=iterations))
232
- if sum(res_full.loss >= accept_tol) > 0:
363
+ self.results_full = res_full
364
+
365
+ metric = "{name}_{x}".format(name=self._name, x=int(self.x * 100),)
366
+
367
+ successes = sum(res_full.loss < accept_tol)
368
+ if successes < draws:
233
369
  warnings.warn(
234
- f"Not all optimizations converged on the {self._name}_{self.x}. " +
370
+ f"Not all optimizations converged on the {metric} ({successes/draws*100}%). " +
235
371
  "Adjust starting values and method")
236
372
  print(res_full)
373
+
374
+ short_msg = f"Estimation success rate: {successes/draws*100}%"
375
+ self.results["msg"] = short_msg
237
376
 
238
377
  res = res_full.loc[res_full.loss < accept_tol,:]
239
378
 
240
- summary = {
241
- "mean": np.round(np.mean(res.factor.values), 4),
242
- "q05": np.round(np.quantile(res.factor.values, 0.05), 4),
243
- "q95": np.round(np.quantile(res.factor.values, 0.95), 4),
244
- "std": np.round(np.std(res.factor.values), 4),
245
- "cv": np.round(np.std(res.factor.values)/np.mean(res.factor.values), 2),
246
- }
379
+ if len(res) == 0:
380
+ self.msg = (
381
+ f"{metric} could not be found. Two reasons typically cause this problem: "+
382
+ f"1) no expoure before the time at which the {metric} is calculated. "+
383
+ "Check the the exposure profile. " +
384
+ f"2) Too high background mortality. If the time at which the {metric} is "+
385
+ f"calculated is large and background mortality is high, the {metric}, " +
386
+ "may be reached independent of the effect and optimization cannot succeed."
387
+ )
247
388
 
248
- self.results = pd.Series(summary)
249
- self.results_full = res_full
389
+ print(self.msg)
390
+ return
391
+
392
+ self.results["mean"] = np.round(np.mean(res.factor.values), 4)
393
+ self.results["q05"] = np.round(np.quantile(res.factor.values, 0.05), 4)
394
+ self.results["q95"] = np.round(np.quantile(res.factor.values, 0.95), 4)
395
+ self.results["std"] = np.round(np.std(res.factor.values), 4)
396
+ self.results["cv"] = np.round(np.std(res.factor.values)/np.mean(res.factor.values), 2)
397
+
398
+ if show_plot:
399
+ self.plot_profile_and_effect(parameters=parameters)
250
400
 
251
401
  print("{name}_{x}".format(name=self._name, x=int(self.x * 100),))
252
402
  print(self.results)
@@ -262,7 +412,7 @@ class ECxEstimator:
262
412
  self.sim.dispatch_constructor()
263
413
 
264
414
  if self._mode is None:
265
- raise RuntimeError(
415
+ raise GutsBaseError(
266
416
  "Run .estimate() before plot_profile_and_effect()"
267
417
  )
268
418
  elif self._mode == "mean" or self._mode == "draws":
@@ -327,6 +477,8 @@ class ECxEstimator:
327
477
  ax2.set_ylim(0, None)
328
478
  fig.tight_layout()
329
479
 
480
+ self.figure_profile_and_effect = fig
481
+
330
482
  self.sim.coordinates["time"] = coordinates_backup
331
483
  self.sim.dispatch_constructor()
332
484
 
@@ -348,4 +500,10 @@ class LPxEstimator(ECxEstimator):
348
500
  ):
349
501
  x_in = sim.model_parameters["x_in"].sel(id=[id])
350
502
  time = sim.coordinates["time"][-1]
351
- super().__init__(sim=sim, effect="survival", x=x, id=id, time=time, x_in=x_in)
503
+ super().__init__(
504
+ sim=sim,
505
+ effect="survival",
506
+ x=x,
507
+ time=time,
508
+ x_in=x_in
509
+ )
guts_base/sim/mempy.py CHANGED
@@ -1,44 +1,64 @@
1
1
  import pathlib
2
- from typing import Dict, Optional, Literal
2
+ from typing import Dict, Optional, Literal, Protocol, TypedDict, List
3
3
  import re
4
-
4
+ import os
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  import xarray as xr
8
- from pymob import SimulationBase
9
8
  from pymob.sim.config import Config, DataVariable, Datastructure
10
9
  from pymob.sim.parameters import Param
11
10
  from guts_base.sim import GutsBase
12
- from mempy.model import (
13
- Model,
14
- RED_IT,
15
- RED_SD,
16
- RED_IT_DA,
17
- RED_SD_DA,
18
- RED_IT_IA,
19
- RED_SD_IA,
20
- BufferGUTS_IT,
21
- BufferGUTS_IT_CA,
22
- BufferGUTS_IT_DA
23
- )
24
11
 
25
12
  __all__ = [
26
13
  "PymobSimulator",
27
14
  ]
28
15
 
16
+ class ParamsInfoDict(TypedDict):
17
+ name: str
18
+ min: float
19
+ max: float
20
+ initial: float
21
+ vary: bool
22
+ prior: str
23
+
24
+ class StateVariablesDict(TypedDict):
25
+ dimensions: List[str]
26
+ observed: bool
27
+ y0: List[float]
28
+
29
+ class Model(Protocol):
30
+ extra_dim: Optional[str]
31
+ params_info: Dict[str, ParamsInfoDict]
32
+ state_variables: Dict[str, StateVariablesDict]
33
+ _params_info_defaults: Dict[str, ParamsInfoDict]
34
+ _it_model: bool
35
+
36
+ @staticmethod
37
+ def _rhs_jax():
38
+ raise NotImplementedError
39
+
40
+ @staticmethod
41
+ def _solver_post_processing():
42
+ raise NotImplementedError
43
+
44
+ @staticmethod
45
+ def _likelihood_func_jax():
46
+ raise NotImplementedError
47
+
48
+
29
49
  class PymobSimulator(GutsBase):
30
50
 
31
51
  @classmethod
32
- def from_mempy(
52
+ def from_model_and_dataset(
33
53
  cls,
34
- exposure_data: Dict,
35
- survival_data: Dict,
36
54
  model: Model,
55
+ exposure_data: Dict[str, pd.DataFrame],
56
+ survival_data: pd.DataFrame,
37
57
  info_dict: Dict = {},
38
58
  pymob_config: Optional[Config] = None,
39
59
  output_directory: str|pathlib.Path = pathlib.Path("output/pymob"),
40
60
  default_prior: Literal["uniform", "lognorm"] = "lognorm",
41
- ) -> SimulationBase:
61
+ ) -> "PymobSimulator":
42
62
  """Construct a PymobSimulator from the
43
63
  """
44
64
 
@@ -46,7 +66,7 @@ class PymobSimulator(GutsBase):
46
66
  cfg = Config()
47
67
  # Configure: The configuration can be overridden in a subclass to override the
48
68
  # configuration
49
- cls.configure(config=cfg)
69
+ cls._configure(config=cfg)
50
70
  else:
51
71
  cfg = pymob_config
52
72
 
@@ -55,6 +75,14 @@ class PymobSimulator(GutsBase):
55
75
 
56
76
  cfg.case_study.output = str(output_directory)
57
77
 
78
+ # overrides scenario path. This means the scenario is also expected in the
79
+ # same folder
80
+ cfg.case_study.scenario_path_override = str(output_directory)
81
+ cfg.case_study.scenario = output_directory.stem
82
+ cfg.case_study.data = cfg.case_study.output_path
83
+ cfg.case_study.observations = "observations.nc"
84
+ cfg.create_directory(directory="results", force=True)
85
+
58
86
  # parse observations
59
87
  # obs can be simply subset by selection obs.sel(substance="Exposure-Dime")
60
88
  observations = xr.combine_by_coords([
@@ -62,9 +90,19 @@ class PymobSimulator(GutsBase):
62
90
  cls._survival_data_to_xarray(survival_data)
63
91
  ])
64
92
 
93
+ observations.to_netcdf(
94
+ os.path.join(cfg.case_study.output_path, cfg.case_study.observations)
95
+ )
96
+
65
97
  # configure model and likelihood function
66
- cfg.simulation.model = type(model).__name__
67
- cfg.inference_numpyro.user_defined_error_model = str(model._likelihood_func_jax.__name__)
98
+ # extract the fully qualified name of the model module.name
99
+ cfg.simulation.model_class = "{module}.{name}".format(
100
+ module=model.__module__, name=type(model).__name__
101
+ )
102
+ cfg.inference_numpyro.user_defined_error_model = "{module}.{name}".format(
103
+ module=model._likelihood_func_jax.__module__,
104
+ name=model._likelihood_func_jax.__name__
105
+ )
68
106
 
69
107
  # derive data structure and params from the model instance
70
108
  cls._set_data_structure(config=cfg, model=model)
@@ -75,35 +113,37 @@ class PymobSimulator(GutsBase):
75
113
  cfg.simulation.y0 = [f"{k}={v['y0']}" for k, v in model.state_variables.items() if "y0" in v]
76
114
 
77
115
  # create a simulation object
116
+ # It is essential that all post processing tasks are done in self.setup()
117
+ # which is extended below. This ensures that the simulation can also be run
118
+ # from automated tools like pymob-infer
78
119
  sim = cls(config=cfg)
79
- sim.config.create_directory(directory="results", force=True)
80
-
81
- # initialize
82
- sim.load_modules()
83
- sim.set_logger()
84
-
85
- sim.initialize(input={"observations": observations, "model": model})
86
-
87
- sim.validate()
88
- sim.dispatch_constructor()
89
-
90
-
120
+ sim.setup()
91
121
  return sim
92
122
 
93
- def initialize(self, input=None):
94
- self.model = input["model"]._rhs_jax
95
- self.solver_post_processing = input["model"]._solver_post_processing
123
+ def reset_observations(self):
124
+ """Resets the observations to the original observations after using .from_mempy(...)
125
+ This also resets the sim.coordinates dictionary.
126
+ """
127
+
128
+ self.observations = self._obs_backup
96
129
 
97
- super().initialize(input=input)
130
+ def setup(self, **evaluator_kwargs):
131
+ super().setup(**evaluator_kwargs)
132
+ self._obs_backup = self.observations.copy(deep=True)
98
133
 
99
134
 
100
135
  @classmethod
101
- def configure(cls, config: Config):
136
+ def _configure(cls, config: Config):
102
137
  """This is normally set in the configuration file passed to a SimulationBase class.
103
138
  Since the mempy to pymob converter initializes pymob.SimulationBase from scratch
104
139
  (without using a config file), the necessary settings have to be specified here.
105
140
  """
106
141
  config.case_study.output = "results"
142
+ config.case_study.simulation = "PymobSimulator"
143
+
144
+ # this must be named guts_base, whihc is the name of the pip package and
145
+ # this regulates which packages are loaded.
146
+ config.case_study.name = "guts_base"
107
147
 
108
148
  config.simulation.x_dimension = "time"
109
149
  config.simulation.batch_dimension = "id"
@@ -132,33 +172,6 @@ class PymobSimulator(GutsBase):
132
172
  config.inference_numpyro.svi_iterations = 10_000
133
173
  config.inference_numpyro.svi_learning_rate = 0.001
134
174
 
135
- @staticmethod
136
- def _exposure_data_to_xarray(exposure_data: Dict[str, pd.DataFrame], dim: str):
137
- """
138
- TODO: Currently no rect interpolation
139
- """
140
- arrays = {}
141
- for key, df in exposure_data.items():
142
- # this override is necessary to make all dimensions work out
143
- df.index.name = "time"
144
- arrays.update({
145
- key: df.to_xarray().to_dataarray(dim="id", name=key)
146
- })
147
-
148
- exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
149
- exposure_array = exposure_array.transpose("id", "time", ...)
150
- return xr.Dataset({"exposure": exposure_array})
151
-
152
- @staticmethod
153
- def _survival_data_to_xarray(survival_data: pd.DataFrame):
154
- # TODO: survival name is currently not kept because the raw data is not transferred from the survival
155
- survival_data.index.name = "time"
156
-
157
- survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
158
- survival_array = survival_array.transpose("id", "time", ...)
159
- arrays = {"survival": survival_array}
160
- return xr.Dataset(arrays)
161
-
162
175
  @classmethod
163
176
  def _set_data_structure(cls, config: Config, model: Model):
164
177
  """Takes a dictionary that is specified in the model and uses only keys that
@@ -179,10 +192,7 @@ class PymobSimulator(GutsBase):
179
192
  def _set_params(cls, config: Config, model: Model, default_prior: str):
180
193
  params_info = model.params_info
181
194
 
182
- if isinstance(model, (
183
- RED_IT, RED_IT_DA, RED_IT_IA,
184
- BufferGUTS_IT, BufferGUTS_IT_CA, BufferGUTS_IT_DA
185
- )):
195
+ if model._it_model:
186
196
  eps = config.jaxsolver.atol * 10
187
197
  params_info["eps"] = {'name':'eps', 'initial':eps, 'vary':False}
188
198
 
@@ -210,6 +220,11 @@ class PymobSimulator(GutsBase):
210
220
  _init = group["initial"].values.astype(float)
211
221
  _free = group["vary"].values
212
222
 
223
+ if isinstance(_min, np.ma.core.MaskedConstant):
224
+ _min = None
225
+ if isinstance(_max, np.ma.core.MaskedConstant):
226
+ _max = None
227
+
213
228
  # TODO: allow for parsing one N-D prior from multiple priors
214
229
  # TODO: Another choice would be to parse vary=False priors as deterministic
215
230
  # and use a composite prior from a deterministic and a free prior as
guts_base/sim/report.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import os
2
2
  import itertools as it
3
+ from typing import List
3
4
  import pandas as pd
4
5
 
5
6
  from pymob import SimulationBase
@@ -9,6 +10,8 @@ from guts_base.plot import plot_survival_multipanel
9
10
  from guts_base.sim.ecx import ECxEstimator
10
11
 
11
12
  class GutsReport(Report):
13
+ ecx_estimates_times: List = [1, 2, 4, 10]
14
+ ecx_estimates_x: List = [0.1, 0.25, 0.5, 0.75, 0.9]
12
15
 
13
16
  def additional_reports(self, sim: "SimulationBase"):
14
17
  super().additional_reports(sim=sim)
@@ -33,9 +36,9 @@ class GutsReport(Report):
33
36
 
34
37
  @reporting
35
38
  def LCx_estimates(self, sim):
36
- X = [0.1, 0.25, 0.5, 0.75, 0.9]
37
- T = [1, 2]
38
- P = sim.predefined_scenarios
39
+ X = self.ecx_estimates_x
40
+ T = self.ecx_estimates_times
41
+ P = sim.predefined_scenarios()
39
42
 
40
43
  estimates = pd.DataFrame(
41
44
  it.product(X, T, P.keys()),
@@ -49,7 +52,6 @@ class GutsReport(Report):
49
52
  sim=sim,
50
53
  effect="survival",
51
54
  x=row.x,
52
- id=None,
53
55
  time=row.time,
54
56
  x_in=P[row.scenario],
55
57
  )
@@ -57,6 +59,7 @@ class GutsReport(Report):
57
59
  ecx_estimator.estimate(
58
60
  mode=sim.ecx_mode,
59
61
  draws=250,
62
+ show_plot=False
60
63
  )
61
64
 
62
65
  ecx.append(ecx_estimator.results)
@@ -64,9 +67,6 @@ class GutsReport(Report):
64
67
  results = pd.DataFrame(ecx)
65
68
  estimates[results.columns] = results
66
69
 
67
-
68
- estimates.to_csv()
69
- file = os.path.join(sim.output_path, "lcx_estimates.csv")
70
- lab = self._label.format(placeholder='$LC_x$ estimates')
71
- self._write_table(tab=estimates, label_insert=f"$LC_x$ estimates \label{{{lab}}}]({os.path.basename(file)})")
70
+ out = self._write_table(tab=estimates, label_insert="$LC_x$ estimates")
72
71
 
72
+ return out