guts-base 2.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
guts_base/sim/ecx.py ADDED
@@ -0,0 +1,585 @@
1
+ import warnings
2
+ from functools import partial
3
+ import numpy as np
4
+ import xarray as xr
5
+ from typing import Literal, Optional, Dict, List
6
+ import pandas as pd
7
+ from scipy.optimize import minimize
8
+ from matplotlib import pyplot as plt
9
+ from tqdm import tqdm
10
+
11
+ from pymob import SimulationBase
12
+ from guts_base.sim.utils import GutsBaseError
13
+
14
class ECxEstimator:
    """Estimate the exposure multiplication factor that produces a given effect.

    The algorithm operates by varying a given exposure profile (``x_in``) by a
    scalar multiplication factor. It searches for the factor at which the
    simulated ``effect`` variable equals ``1 - x`` at ``time``. For each new
    estimation, a new estimator is initialized.

    Parameters
    ----------

    sim : SimulationBase
        This must be a pymob.SimulationBase object. The estimator works on a copy
        (``sim.copy()``). If the ECxEstimator.estimate method is used with the modes
        'draws' or 'mean', the simulation must contain a posterior
        (``sim.inferer.idata.posterior``).

    effect : str
        The data variable for which the effect concentration is computed. This is one
        of sim.observations and sim.results

    x : float
        Effect level. This is the level of the effect, for which the concentration is
        computed. The optimizer targets ``effect == 1 - x`` at ``time``; e.g. x=0.5
        yields the EC_50.

    time : float
        Time at which the ECx is computed. It is inserted into the time vector of the
        simulation.

    x_in : xr.Dataset
        The model input 'x_in' for which the effect is computed. Its 'id' coordinate
        replaces the IDs of the simulation's observations.

    conditions_posterior : Dict[str, float], optional
        Dictionary that overwrites values in the posterior. This is useful if for instance
        background mortality should be set to a fixed value (e.g. zero). Consequently this
        setting does not take effect in estimation mode 'manual' but only for mean and
        draws. Defaults to None (no conditions applied).
    """
    _name = "EC"
    _parameter_msg = (
        "Manual estimation (mode='manual', without using posterior information) requires " +
        "specification of parameters={...}. You can obtain and modify " +
        "parameters using the pymob API: `sim.config.model_parameters.value_dict` " +
        "returns a dictionary of DEFAULT PARAMETERS that you can customize to your liking " +
        "(https://pymob.readthedocs.io/en/stable/api/pymob.sim.html#pymob.sim.config.Modelparameters.value_dict)."
    )

    def __init__(
        self,
        sim: SimulationBase,
        effect: str,
        x: float,
        time: float,
        x_in: xr.Dataset,
        conditions_posterior: Optional[Dict[str, float]] = None,
    ):
        # None instead of a shared mutable {} default; an empty dict means
        # "no conditions" (checked by truthiness in mode='manual').
        conditions_posterior = {} if conditions_posterior is None else dict(conditions_posterior)

        self.sim = sim.copy()
        self.time = time
        self.x = x
        self.effect = effect
        self._mode = None
        self._conditions_posterior = conditions_posterior

        # creates an empty observation dataset with the coordinates of the
        # original observations (especially time), except the ID, which is overwritten
        # and taken from the x_in dataset
        pseudo_obs = self.sim.observations.isel(id=[0])
        pseudo_obs = pseudo_obs.drop_vars(list(pseudo_obs.data_vars))
        pseudo_obs["id"] = x_in["id"]

        self.sim.config.data_structure.survival.observed = False
        self.sim.observations = pseudo_obs

        # overwrite x_in to make sure that parse_input takes x_in from exposure and
        # does not use the string that is tied to another data variable which was
        # originally present
        self.sim.config.simulation.x_in = ["exposure=exposure"]

        # ensure correct coordinate order with x_in and raise errors early
        self.sim.model_parameters["x_in"] = self.sim.parse_input("x_in", x_in)

        # fix time after observations have been set. The outcome of the simulation
        # can depend on the time vector, because in e.g. IT models, the time resolution
        # is important. Therefore the time at which the ECx is computed is just inserted
        # into the time vector at the right position.
        self.sim.coordinates["time"] = np.unique(np.concatenate([
            self.sim.coordinates["time"], np.array(time, ndmin=1)
        ]))

        self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims=["time"])
        self.sim.dispatch_constructor()

        self.results = self._empty_results()

        self.figure_profile_and_effect = None
        self.figure_loss_curve = None

        self.condition_posterior_parameters(conditions=self._conditions_posterior)

    @staticmethod
    def _empty_results() -> pd.Series:
        """Return a fresh results Series with every entry set to NaN.

        The dtype is ``object`` because ``"msg"`` later holds a string next to the
        numeric summaries; a float64 Series would reject/upcast that assignment.
        """
        return pd.Series(
            {
                "mean": np.nan,
                "q05": np.nan,
                "q95": np.nan,
                "std": np.nan,
                "cv": np.nan,
                "msg": np.nan,
            },
            dtype=object,
        )

    @staticmethod
    def _bad_mode_msg(mode) -> str:
        """Error message for an unsupported estimation mode."""
        return f"Bad mode: {mode}. Mode must be one of 'draws', 'mean' or 'manual'"

    def _metric_name(self) -> str:
        """Label of the estimated metric, e.g. 'EC_50' for x=0.5."""
        # round() instead of int(): int(0.29 * 100) == 28 because of float error
        return f"{self._name}_{round(self.x * 100)}"

    def _assert_posterior(self):
        """Raise GutsBaseError if the simulation has no posterior to draw from."""
        try:
            self.sim.inferer.idata.posterior
        except AttributeError as err:
            raise GutsBaseError(
                "Using mode='mean' or mode='draws', but sim did not contain a posterior. " +
                "('sim.inferer.idata.posterior'). " + self._parameter_msg
            ) from err

    def condition_posterior_parameters(self, conditions: Optional[Dict[str, float]]):
        """Fix parameters to given values before estimation.

        Free parameters are conditioned in the posterior via
        ``sim._condition_posterior(..., exception="raise")``; non-free parameters are
        written to the model configuration and the simulation is re-dispatched.
        ``None`` or an empty dict is a no-op.
        """
        for parameter, value in (conditions or {}).items():
            if self.sim.config.model_parameters[parameter].free:
                self.sim.inferer.idata.posterior = self.sim._condition_posterior(
                    posterior=self.sim.inferer.idata.posterior,
                    parameter=parameter,
                    value=value,
                    exception="raise",
                )
            else:
                self.sim.config.model_parameters[parameter].value = value
                self.sim.model_parameters["parameters"] = self.sim.config.model_parameters.value_dict
                self.sim.dispatch_constructor()

    def _evaluate(self, factor, theta):
        """Run the model with the exposure ``x_in`` multiplied by ``factor``.

        Returns the evaluator after it has been called.
        """
        evaluator = self.sim.dispatch(
            theta=theta,
            x_in=self.sim.validate_model_input(self.sim.model_parameters["x_in"] * factor)
        )
        evaluator()
        return evaluator

    def _loss(self, log_factor, theta):
        """Squared distance between the simulated effect at ``time`` and ``1 - x``.

        ``log_factor`` is the natural log of the exposure multiplication factor.
        """
        factor = np.exp(log_factor)

        e = self._evaluate(factor, theta)
        s = e.results.sel(time=self.time)[self.effect].values

        return (s - (1 - self.x)) ** 2

    def _posterior_mean(self):
        """Return the posterior mean over chains and draws as a parameter dict."""
        mean = self.sim.inferer.idata.posterior.mean(("chain", "draw"))
        return {k: v["data"] for k, v in mean.to_dict()["data_vars"].items()}

    def _posterior_sample(self, i, posterior_stacked=None):
        """Return the i-th posterior sample (chains and draws stacked) as a dict.

        ``posterior_stacked`` may be passed to reuse an already stacked posterior
        instead of stacking it again on every call.
        """
        if posterior_stacked is None:
            posterior_stacked = self.sim.inferer.idata.posterior.stack(
                sample=("chain", "draw")
            )
        sample = posterior_stacked.isel(sample=i)
        return {k: v["data"] for k, v in sample.to_dict()["data_vars"].items()}

    def _iter_thetas(self, mode, draws, parameters):
        """Yield one parameter set per iteration for an already validated mode."""
        if mode == "draws":
            # stack once; stacking inside the loop would repeat the work per draw
            stacked = self.sim.inferer.idata.posterior.stack(sample=("chain", "draw"))
            for i in range(draws):
                yield self._posterior_sample(i, posterior_stacked=stacked)
        elif mode == "mean":
            mean = self._posterior_mean()
            for _ in range(draws):
                yield mean
        elif mode == "manual":
            for _ in range(draws):
                yield parameters
        else:
            raise NotImplementedError(self._bad_mode_msg(mode))

    def plot_loss_curve(self,
        mode: Literal["draws", "mean", "manual"] = "draws",
        draws: Optional[int] = None,
        parameters: Optional[Dict[str,float|List[float]]] = None,
        log_x0: float = 0.0,
        log_interval_radius: float = 2.0,
        log_interval_num: int = 100,
        force_draws: bool = False
    ):
        """Evaluate and plot the loss on a grid of multiplication factors.

        The figure is stored in ``self.figure_loss_curve`` and the log-scale factor
        with the lowest loss of every curve in ``self.ecx_candidates``.

        Parameters
        ----------

        mode : Literal['draws', 'mean', 'manual']
            mode of estimation. mode='mean' takes the mean of the posterior and estimate
            the ECx for this singular value. mode='draws' takes samples from the posterior
            and estimate the ECx for each of the parameter draws. mode='manual' takes
            a parameter set (Dict) in the parameters argument and uses that for estimation.
            Default: 'draws'

        draws : int
            Number of draws to take from the posterior. Only takes effect if mode='draws'.
            Raises an exception if draws < 100, because this is insufficient for a
            reasonable uncertainty estimate. Default: None (using all samples from the
            posterior)

        parameters : Dict[str,float|list[float]]
            a parameter dictionary passed used as model parameters for finding the ECx
            value. Default: None

        log_x0 : float
            center of the evaluated interval on the log scale; exp(log_x0=0.0) = 1.0
            corresponds to the unmodified exposure profile. Default: 0.0

        log_interval_radius : float
            the loss is evaluated for log factors from log_x0 - log_interval_radius to
            log_x0 + log_interval_radius. Default: 2.0

        log_interval_num : int
            number of evenly spaced points in that interval. Default: 100

        force_draws : bool
            Accept a number of draws less than 100. Default: False
        """
        draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)

        log_factors = np.linspace(-log_interval_radius, log_interval_radius, log_interval_num) + log_x0
        fig, ax = plt.subplots(1, 1, figsize=(4, 3))

        lowest = []
        for theta in tqdm(self._iter_thetas(mode, draws, parameters), total=draws):
            y = [self._loss(log_factor, theta=theta) for log_factor in log_factors]
            lowest.append(log_factors[np.array(y).argmin()])
            ax.plot(np.exp(log_factors), y, color="black")

        # log-scale multiplication factors at the minimum of each loss curve
        self.ecx_candidates = lowest

        # raw strings: \ell, \cdot and \phi are LaTeX, not Python escapes
        ax.plot(
            [], [], color="black",
            label=rf"$\ell = (S(t={self.time},x_{{in}}=C_{{ext}} \cdot \phi) - (1 - {self.x}))^2$"
        )
        ax.set_ylabel(r"Loss ($\ell$)")
        ax.set_xlabel(r"Multiplication factor ($\phi$)")
        ax.set_title(f"ID: {self.sim.coordinates['id'][0]}")
        ax.set_ylim(0, ax.get_ylim()[1] * 1.25)
        ax.set_xscale("log")
        ax.legend(frameon=False)
        fig.tight_layout()

        self.figure_loss_curve = fig

    def _check_mode_and_draws_and_parameters(self, mode, draws, parameters, force_draws):
        """Validate the estimation settings and return the number of iterations.

        Raises
        ------
        GutsBaseError
            For an unknown mode, a missing posterior (modes 'draws'/'mean'), missing
            parameters (mode 'manual'), or an invalid number of draws.
        """
        if mode == "draws":
            self._assert_posterior()
            posterior = self.sim.inferer.idata.posterior
            n_available = posterior.sizes["chain"] * posterior.sizes["draw"]

            if draws is None:
                draws = n_available
            elif draws < 1:
                raise GutsBaseError(f"draws must be a positive integer, got {draws}.")
            elif draws > n_available:
                raise GutsBaseError(
                    f"draws={draws} exceeds the number of posterior samples ({n_available})."
                )
            elif draws < 100 and not force_draws:
                raise GutsBaseError(
                    "draws must be at least 100. Preferably > 1000. " +
                    f"If you don't want uncertainty assessment of the {self._name} " +
                    "estimates, use mode='mean'. If you really want to use less than " +
                    "100 draws, use force_draws=True at your own risk."
                )

            if parameters is not None:
                warnings.warn(
                    f"Values passed to 'parameters' don't have an effect in mode='{mode}'"
                )

        elif mode == "mean":
            self._assert_posterior()
            draws = 1

            if parameters is not None:
                warnings.warn(
                    f"Values passed to 'parameters' don't have an effect in mode='{mode}'"
                )

        elif mode == "manual":
            draws = 1
            if parameters is None:
                raise GutsBaseError(self._parameter_msg)

            # only warn if conditions were actually given
            if self._conditions_posterior:
                warnings.warn(
                    "Conditions applied to the posterior do not take effect in mode 'manual'"
                )

        else:
            raise GutsBaseError(self._bad_mode_msg(mode))

        return draws

    def _optimize_draw(self, theta, x0_tries, method, optimizer_tol, accept_tol, optimizer_kwargs):
        """Minimize the loss for one parameter set, retrying from the next start value.

        Returns the last ``OptimizeResult`` and the number of attempts made. Stops at
        the first attempt whose loss is below ``accept_tol``.
        """
        opt_res = None
        n_tries = 0
        for x0 in x0_tries:
            n_tries += 1
            opt_res = minimize(
                self._loss, x0=x0,
                method=method,
                tol=optimizer_tol,
                args=(theta,),
                **optimizer_kwargs
            )
            if opt_res.fun < accept_tol:
                break
        return opt_res, n_tries

    def estimate(
        self,
        mode: Literal["draws", "mean", "manual"] = "draws",
        draws: Optional[int] = None,
        parameters: Optional[Dict[str,float|List[float]]] = None,
        log_x0: float = 0.0,
        x0_retries: Optional[List[float]] = None,
        accept_tol: float = 1e-5,
        optimizer_tol: float = 1e-5,
        method: str = "cobyla",
        show_plot: bool = True,
        force_draws: bool = False,
        **optimizer_kwargs
    ):
        """The minimizer for the EC_x operates on the unbounded linear scale, estimating
        the log-modification factor. Converted to the linear scale by factor=exp(x), the
        profile modification factor is obtained.

        Using x0=0.0 means optimization will start on the linear scale at the unmodified
        exposure profile. Using the log scale for optimization will provide much smoother
        optimization performance because multiplicative steps on the log scale require
        much less adaptation.

        Results are written to ``self.results`` (mean, q05, q95, std, cv of the accepted
        multiplication factors, plus a success-rate message) and all per-draw results
        to ``self.results_full``. Previous results are discarded.

        Parameters
        ----------

        mode : Literal['draws', 'mean', 'manual']
            mode of estimation. mode='mean' takes the mean of the posterior and estimate
            the ECx for this singular value. mode='draws' takes samples from the posterior
            and estimate the ECx for each of the parameter draws. mode='manual' takes
            a parameter set (Dict) in the parameters argument and uses that for estimation.
            Default: 'draws'

        draws : int
            Number of draws to take from the posterior. Only takes effect if mode='draws'.
            Raises an exception if draws < 100, because this is insufficient for a
            reasonable uncertainty estimate. Default: None (using all samples from the
            posterior)

        parameters : Dict[str,float|list[float]]
            a parameter dictionary passed used as model parameters for finding the ECx
            value. Default: None

        log_x0 : float
            the starting value for the multiplication factor of the exposure profile for
            the minimization algorithm. This value is on the log scale. This means,
            exp(log_x0=0.0) = 1.0, which means that the log_x0=0.0 will start at an
            unmodified exposure profile. Default: 0.0

        x0_retries : List[float]
            starting points (added to log_x0) tried in order until the loss falls below
            accept_tol. For example, if log_x0=0.0 and x0_retries=[-1.0, 1.0], the
            minimization starts at exp(-1.0) and, if that fails, at exp(1.0).
            Must not be empty. Default: None, meaning [0.0, -1.0, 1.0, -2.0, 2.0]

        accept_tol : float
            After optimization is finished, accept_tol is used to assess if the loss
            function for the individual draws exceed a tolerance. These results are
            discarded and a warning is emitted. This is to assert that no faulty optimization
            results enter the estimate. Default: 1e-5

        optimizer_tol : float
            Tolerance limit for the minimizer to stop optimization. Default 1e-5

        method : str
            Minimization algorithm. See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
            Default: 'cobyla'

        show_plot : bool
            Create the profile/effect plot via plot_profile_and_effect. Default: True

        force_draws : bool
            Force the estimate method to accept a number of draws less than 100. Default: False

        optimizer_kwargs :
            Additional arguments to pass to the optimizer

        Raises
        ------
        GutsBaseError
            If x0_retries is empty or the mode/draws/parameters are invalid.
        """
        if x0_retries is None:
            x0_retries = [0.0, -1.0, 1.0, -2.0, 2.0]
        x0_tries = np.array(x0_retries, dtype=float) + log_x0
        if x0_tries.size == 0:
            # otherwise no optimization would run and no result would exist
            raise GutsBaseError("x0_retries must contain at least one starting value.")

        draws = self._check_mode_and_draws_and_parameters(mode, draws, parameters, force_draws)

        self._mode = mode
        # discard results of previous calls so a failed run cannot leave stale values
        self.results = self._empty_results()
        metric = self._metric_name()

        mult_factor = []
        loss = []
        iterations = []
        for theta in tqdm(self._iter_thetas(mode, draws, parameters), total=draws):
            opt_res, n_tries = self._optimize_draw(
                theta, x0_tries, method, optimizer_tol, accept_tol, optimizer_kwargs
            )
            # convert to linear scale from log scale
            mult_factor.extend(np.exp(opt_res.x))
            iterations.append(n_tries)
            loss.append(opt_res.fun)

        res_full = pd.DataFrame(dict(factor=mult_factor, loss=loss, retries=iterations))
        self.results_full = res_full

        accepted = res_full.loss < accept_tol
        successes = int(accepted.sum())
        success_rate = successes / draws * 100
        if successes < draws:
            warnings.warn(
                f"Not all optimizations converged on the {metric} ({success_rate}%). " +
                "Adjust starting values and method")
            print(res_full)

        self.results["msg"] = f"Estimation success rate: {success_rate}%"

        res = res_full.loc[accepted, :]

        if len(res) == 0:
            self.msg = (
                f"{metric} could not be found. Two reasons typically cause this problem: "+
                f"1) no expoure before the time at which the {metric} is calculated. "+
                "Check the the exposure profile. " +
                f"2) Too high background mortality. If the time at which the {metric} is "+
                f"calculated is large and background mortality is high, the {metric}, " +
                "may be reached independent of the effect and optimization cannot succeed."
            )

            print(self.msg)
            return

        factors = res.factor.values
        mean = np.mean(factors)
        std = np.std(factors)
        self.results["mean"] = np.round(mean, 4)
        self.results["q05"] = np.round(np.quantile(factors, 0.05), 4)
        self.results["q95"] = np.round(np.quantile(factors, 0.95), 4)
        self.results["std"] = np.round(std, 4)
        self.results["cv"] = np.round(std / mean, 2)

        if show_plot:
            self.plot_profile_and_effect(parameters=parameters)

        print(metric)
        print(self.results)
        print("\n")

    def plot_profile_and_effect(
        self,
        parameters: Optional[Dict[str,float|List[float]]] = None
    ):
        """Plot original vs. rescaled exposure profile and the resulting effect.

        The rescaled profile uses ``self.results["mean"]`` as multiplication factor.
        Mode 'mean'/'draws' estimates are plotted with the posterior mean; mode
        'manual' estimates need ``parameters``. The simulation's time coordinates are
        temporarily replaced by a 100-point grid from 0 to ``time`` and always
        restored afterwards. The figure is stored in ``self.figure_profile_and_effect``.

        Raises
        ------
        GutsBaseError
            If estimate() was not run or did not produce an estimate.
        RuntimeError
            If the estimate used mode='manual' and no parameters are given.
        """
        # validate before touching the simulation state
        if self._mode is None:
            raise GutsBaseError(
                "Run .estimate() before plot_profile_and_effect()"
            )

        factor = self.results["mean"]
        if pd.isna(factor):
            raise GutsBaseError(
                f"No {self._metric_name()} estimate available to plot. "
                "Check the output of .estimate()."
            )

        if self._mode == "mean" or self._mode == "draws":
            theta = self._posterior_mean()
        elif self._mode == "manual":
            if parameters is None:
                raise RuntimeError(
                    f"If {self._name}_x was estimated using manual mode, parameters must "+
                    "also be provided here."
                )
            theta = parameters
        else:
            raise GutsBaseError(self._bad_mode_msg(self._mode))

        coordinates_backup = self.sim.coordinates["time"].copy()
        self.sim.coordinates["time"] = np.linspace(0, self.time, 100)
        self.sim.dispatch_constructor()
        try:
            e_new = self._evaluate(factor=factor, theta=theta)
            e_old = self._evaluate(factor=1.0, theta=theta)
            self.figure_profile_and_effect = self._draw_profile_and_effect(e_old, e_new)
        finally:
            # restore the simulation even if evaluation or plotting fails
            self.sim.coordinates["time"] = coordinates_backup
            self.sim.dispatch_constructor()

    def _draw_profile_and_effect(self, e_old, e_new):
        """Draw exposure (top) and effect (bottom) of original and modified profile."""
        extra_dim = [k for k in e_old.results.coords.keys() if k not in ["time", "id"]]

        if len(extra_dim) > 0:
            labels_old = [
                f"{level} (original)" for level
                in e_old.results.coords[extra_dim[0]].values
            ]
            labels_new = [
                f"{level} (modified)" for level
                in e_new.results.coords[extra_dim[0]].values
            ]
        else:
            labels_old = "original"
            labels_new = "modified"

        fig, (ax1, ax2) = plt.subplots(2, 1, height_ratios=[1, 3], sharex=True)
        ax1.plot(
            e_old.results.time, e_old.results.exposure.isel(id=0),
            ls="--", label=labels_old,
        )
        # restart the colour cycle so original and modified lines share colours
        ax1.set_prop_cycle(None)
        ax1.plot(
            e_new.results.time, e_new.results.exposure.isel(id=0),
            label=labels_new
        )

        ax2.plot(
            e_new.results.time, e_new.results[self.effect].isel(id=0),
            color="black", ls="--", label="modified"
        )
        ax2.plot(
            e_old.results.time, e_old.results[self.effect].isel(id=0),
            color="black", ls="-", label="original"
        )
        # target level of the effect variable; _loss minimizes the distance to 1 - x
        ax2.hlines(1 - self.x, e_new.results.time[0], self.time, color="grey")
        ax1.set_ylabel("Exposure")
        ax2.set_ylabel(self.effect.capitalize())
        ax2.set_xlabel("Time")
        ax1.legend()
        ax2.legend()
        ax2.set_xlim(0, None)
        ax1.set_ylim(0, None)
        ax2.set_ylim(0, None)
        fig.tight_layout()

        return fig
560
+
561
+
562
+
563
class LPxEstimator(ECxEstimator):
    """
    the LPx is computed, using the existing exposure profile for
    the specified ID and estimating the multiplication factor for the profile that results
    in an effect of X %

    The effect variable is 'survival' and the LPx is evaluated at the last time point
    of the simulation's time coordinates.

    Parameters
    ----------

    sim : SimulationBase
        pymob simulation whose ``model_parameters["x_in"]`` contains the exposure
        profile of ``id``.

    id : str
        ID of the exposure profile that is rescaled.

    x : float
        Effect level; the optimizer targets survival == 1 - x. Default: 0.5 (LP_50)

    conditions_posterior : Dict[str, float], optional
        Values that overwrite parameters in the posterior (see ECxEstimator).
        Default: None (no conditions applied).
    """
    _name = "LP"

    def __init__(
        self,
        sim: SimulationBase,
        id: str,
        x: float = 0.5,
        conditions_posterior: Optional[Dict[str, float]] = None,
    ):
        # keep the list-valued selection so the 'id' dimension is preserved
        x_in = sim.model_parameters["x_in"].sel(id=[id])
        time = sim.coordinates["time"][-1]
        super().__init__(
            sim=sim,
            effect="survival",
            x=x,
            time=time,
            x_in=x_in,
            # pass a dict (never None) so this works with any base-class default
            conditions_posterior={} if conditions_posterior is None else conditions_posterior,
        )
+ )