guts-base 0.8.6__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

guts_base/__init__.py CHANGED
@@ -4,11 +4,12 @@ from . import data
4
4
  from . import prob
5
5
  from . import plot
6
6
 
7
- __version__ = "0.8.6"
7
+ __version__ = "1.0.1"
8
8
 
9
9
  from .sim import (
10
10
  GutsBase,
11
11
  PymobSimulator,
12
12
  ECxEstimator,
13
13
  LPxEstimator,
14
+ GutsBaseError,
14
15
  )
@@ -21,7 +21,7 @@ from .survival import (
21
21
  generate_survival_repeated_observations
22
22
  )
23
23
 
24
- from .generator import create_artificial_data, design_exposure_scenario
24
+ from .generator import create_artificial_data, design_exposure_scenario, ExposureDataDict
25
25
 
26
26
  from .expydb import (
27
27
  to_dataset,
@@ -11,7 +11,8 @@ def create_artificial_data(
11
11
  t_max,
12
12
  dt,
13
13
  exposure_paths=["oral", "topical", "contact"],
14
- intensity=[0.1, 0.5, 0.05]
14
+ intensity=[0.1, 0.5, 0.05],
15
+ seed=1,
15
16
  ):
16
17
  rng = np.random.default_rng(1)
17
18
  time = np.arange(0, t_max, step=dt) # daily time resolution
@@ -40,6 +40,12 @@ def prepare_survival_data_for_conditional_binomial(observations: xr.Dataset) ->
40
40
  nsurv.isel(time=list(range(0, len(nsurv.time)-1))).values
41
41
  ]).astype(int))})
42
42
 
43
+ observations = observations.assign_coords({
44
+ "survivors_at_start": (("id", "time"), np.broadcast_to(
45
+ nsurv.isel(time=0).values.reshape(-1,1),
46
+ shape=nsurv.shape
47
+ ).astype(int))})
48
+
43
49
  return observations
44
50
 
45
51
 
guts_base/mod.py CHANGED
@@ -1,10 +1,3 @@
1
- """
2
- GUTS models
3
-
4
- TODO: Import guts models from mempy and update to work well with jax
5
- TODO: Based on this implement the bufferguts model in mempy and import in the
6
- bufferguts case-study
7
- """
8
1
  from functools import partial
9
2
  from pathlib import Path
10
3
 
@@ -19,79 +12,37 @@ from pymob.solvers.symbolic import (
19
12
  PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
20
13
  )
21
14
 
22
- from mempy.model import (
23
- RED_IT,
24
- RED_SD,
25
- RED_IT_DA,
26
- RED_SD_DA,
27
- RED_IT_IA,
28
- RED_SD_IA,
29
- BufferGUTS_IT,
30
- BufferGUTS_IT_CA,
31
- BufferGUTS_IT_DA
32
- )
15
+ class RED_SD:
16
+ """Simplest guts model, mainly for testing"""
17
+ @staticmethod
18
+ def _rhs_jax(t, y, x_in, kd, b, m, hb):
19
+ D, H = y
20
+ dD_dt = kd * (x_in.evaluate(t) - D)
21
+ dH_dt = b * jnp.maximum(0.0, D - m) + hb
22
+ return dD_dt, dH_dt
23
+
24
+ @staticmethod
25
+ def _solver_post_processing(results, t, interpolation):
26
+ results["survival"] = jnp.exp(-results["H"])
27
+ results["exposure"] = jax.vmap(interpolation.evaluate)(t)
28
+ return results
33
29
 
34
30
  red_sd = RED_SD._rhs_jax
35
31
  red_sd_post_processing = RED_SD._solver_post_processing
36
32
 
37
- red_it = RED_IT._rhs_jax
38
- red_it_post_processing = RED_IT._solver_post_processing
33
+ class RED_SD_DA(RED_SD):
34
+ @staticmethod
35
+ def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
36
+ D, H = y
37
+ dD_dt = kd * (x_in.evaluate(t) - D)
38
+ dH_dt = b * jnp.maximum(0.0, jnp.sum(w * D) - m) + hb
39
+ return dD_dt, dH_dt
39
40
 
40
- red_sd_ia = RED_SD_IA._rhs_jax
41
- red_sd_ia_post_processing = RED_SD_IA._solver_post_processing
41
+ red_sd_da = RED_SD_DA._rhs_jax
42
+ red_sd_da_post_processing = RED_SD_DA._solver_post_processing
42
43
 
43
44
 
44
- def p_survival(results, t, interpolation, z, k_k, h_b):
45
- """Computes the stochastic death survival probability after computing
46
-
47
- """
48
- # calculate survival
49
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
50
- p_surv = survival_jax(t, results["D"], z, k_k, h_b)
51
- results["survival"] = p_surv
52
- results["lethality"] = 1 - p_surv
53
- return results
54
-
55
- def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
56
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
57
- p_surv = survival_IT_jax(t, results["D"], alpha, beta, h_b, eps)
58
- results["survival"] = p_surv
59
- results["H"] = - jnp.log(p_surv)
60
- return results
61
-
62
- def post_exposure(results, t, interpolation):
63
- results["survival"] = jnp.exp(-results["H"])
64
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
65
- return results
66
-
67
-
68
- def no_post_processing(results):
69
- return results
70
-
71
-
72
- @jax.jit
73
- def survival_jax(t, damage, z, kk, h_b):
74
- """
75
- survival probability derived from hazard
76
- first calculate cumulative Hazard by integrating hazard cumulatively over t
77
- then calculate the resulting survival probability
78
- It was checked that `survival_jax` behaves exactly the same as `survival`
79
- """
80
- hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
81
- H = jnp.array([jax.scipy.integrate.trapezoid(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
82
- # H = jnp.array([jnp.trapz(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
83
- S = jnp.exp(-H)
84
-
85
- return S
86
-
87
- @jax.jit
88
- def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
89
- d_max = jnp.squeeze(jnp.array([jnp.max(damage[:i+1])+eps for i in range(len(t))]))
90
- F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
91
- S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
92
- return S
93
-
94
- def guts_jax(t, y, C_0, k_d, z, b, h_b):
45
+ def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
95
46
  # for constant exposure
96
47
  D, H, S = y
97
48
  dD_dt = k_d * (C_0 - D)
@@ -103,16 +54,6 @@ def guts_jax(t, y, C_0, k_d, z, b, h_b):
103
54
 
104
55
  return dD_dt, dH_dt, dS_dt
105
56
 
106
-
107
- def RED_IT(t, y, x_in, kd):
108
- D, = y
109
- C = x_in.evaluate(t)
110
-
111
- dD_dt = kd * (C - D)
112
-
113
- return (dD_dt, )
114
-
115
-
116
57
  def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
117
58
  # for constant exposure
118
59
  D, H, S = y
guts_base/prob.py CHANGED
@@ -2,270 +2,10 @@ from functools import partial
2
2
  import numpy as np
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- import numpyro
6
5
  from numpyro.infer import Predictive
7
6
  from typing import Literal
8
7
 
9
8
 
10
- maximum = jnp.frompyfunc(jnp.maximum, nin=2, nout=1, identity=None) # type: ignore
11
-
12
-
13
- @jax.jit
14
- def ffill_na(x, mask):
15
- """Forward-fill nan values in x. If a mask is provided, assume
16
-
17
- Parameters
18
- ----------
19
- x : _type_
20
- _description_
21
- mask : _type_
22
- _description_
23
-
24
- Returns
25
- -------
26
- _type_
27
- _description_
28
- """
29
- if mask is None:
30
- mask = jnp.logical_not(jnp.isnan(x))
31
- idx = jnp.where(mask,jnp.arange(mask.shape[1]),0)
32
- idx_ = maximum.accumulate(jnp.array(idx), axis=1).astype(int)
33
- return x[jnp.arange(idx.shape[0])[:,None], idx_]
34
-
35
-
36
- def conditional_survival_from_hazard(x, mask):
37
- """Calculates the conditional survival from cumulative hazard values.
38
- This equation is used when survival is repeatedly observed over time and
39
-
40
- Parameters
41
- ----------
42
- x : np.ndarray[I,T, float]
43
- A 2-dimensional I x T array of cumulative hazards defined as H = -ln(S).
44
- I is the batch dimension and T is the time dimension
45
-
46
- mask : np.ndarray[I,T, bool]
47
- A 2-dimensional array of the same shape as x, taking True if the survival
48
- was observed for the given index (i,t) and taking False if survival was
49
- not observed for the given index (i,t).
50
-
51
- Returns
52
- -------
53
- out : np.ndarray[I,T, float]
54
- A matrix with conditional probabilities and nans in place where the
55
- mask has nans. Output has the same shape as input.
56
-
57
- Example
58
- -------
59
- Calculation example from survival probabilities to conditional survival
60
- probabilities given some masked values.
61
-
62
- >>> S_i = np.array([
63
- >>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
64
- >>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
65
- >>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
66
- >>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
67
- >>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
68
- >>> ])
69
-
70
- >>> mask_obs = np.array([
71
- >>> [ True, True, True, True, True, True],
72
- >>> [ True, False, True, True, True, True],
73
- >>> [ True, False, False, True, True, True],
74
- >>> [ True, True, False, True, True, True],
75
- >>> [ True, False, True, False, True, True],
76
- >>> ])
77
-
78
- >>> conditional_survival_from_hazard(-jnp.log(S_i), mask_obs)
79
- array([
80
- [1.0 0.75 0.133 0.5 0.2 0.0]
81
- [1.0 nan 0.1 0.5 0.2 0.0]
82
- [1.0 nan nan 0.05 0.2 0.0]
83
- [1.0 0.75 nan 0.0667 0.2 0.0]
84
- [1.0 nan 0.1 nan 0.1 0.0]
85
- ])
86
- """
87
-
88
- # Append zeros (hazard) to the beginning of the array (this aligns with the
89
- # safe assumption that before the zeroth observation S(t=-1) = 1.0)
90
- x_ = jnp.column_stack([jnp.zeros_like(x[:, 0]), x])
91
-
92
- # also mask needs to be expanded accordingly
93
- mask_ = jnp.column_stack([jnp.ones_like(mask[:, 0]), mask])
94
-
95
- # fill NaNs with forward
96
- x_filled = ffill_na(x_, mask_)
97
-
98
- # calculate the conditional survival.
99
- conditional_survival = jnp.exp(x_filled[:, :-1] - (x_[:, 1:]))
100
-
101
- # add nans and return
102
- return jnp.where(
103
- mask, conditional_survival, jnp.nan
104
- )
105
-
106
-
107
- def conditional_survival_hazard_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
108
- """Computes the likelihood of observing K survivors at time t given that
109
- N survivors were alive at the previous observation t-1
110
-
111
- This is achieved by using the Cumulative hazard directly in the calculation
112
- of the conditional survival probability. Note that this equation can easily
113
- be adapted to compute the conditional lethality probability.
114
-
115
- $$\\Pr(t < T~|~t_0 < T) = \\frac{e^{-H(t)}}{e^{-H(t_0)}} = e^{-H(t) - (- H(t_0))} = e^{-H(t) + H(t_0)} = e^{-(H(t) - H(t_0))} = e^{H(t_0) - H(t)}$$
116
-
117
- Example
118
- -------
119
-
120
- t0 t1 t2 t3 tinf
121
- observations 10 8 5 2 0
122
- Pr(T > t) 1.0 0.8 0.5 0.2 0.0
123
- Pr(T > t | T > t-1) (1.0) 0.8 0.625 0.4 0.0
124
-
125
-
126
- If the experiment ends before the last organism has died, these information
127
- have to be included in he computation of the likelihood. For multinomial
128
- distributions this is done by computing the remaining (ubobserved) lethality
129
- until the end of the experiment, whihc is just the inverse of the number of
130
- survivors at the end of the experiment. Therefore, for conditional survival
131
- the information about the number of alive organisms
132
- at the end of the experiment is contained in the last observation.
133
-
134
- We compute an additional interval until the start of the experiment in order
135
- to align the results with the observations (which include t=0). Otherwise
136
- we would only compute T-1 intervals if we have T observations (including t=0).
137
- """
138
-
139
- H = simulation_results["H"]
140
- n_surv = observations["survivors_before_t"]
141
- S_mask = masks["survival"]
142
- obs_survival = observations["survival"]
143
- S_conditional = conditional_survival_from_hazard(H, S_mask)
144
-
145
- if make_predictions:
146
- obs_survival = None
147
-
148
- ID, T = S_mask.shape
149
-
150
- with numpyro.plate("time", size=T):
151
- with numpyro.plate("id", size=ID):
152
- numpyro.sample(
153
- "survival_obs", numpyro.distributions.Binomial(
154
- probs=S_conditional,
155
- total_count=n_surv
156
- ).mask(S_mask),
157
- obs=obs_survival
158
- )
159
-
160
-
161
- def conditional_survival_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
162
- """Note that the conditional survival error model will not work to generate
163
- posterior predictions for observational uncertainty out of the box
164
-
165
- This model can be used for SD and IT models, but for SD models, the
166
- conditional_survival_hazard_model is preferrable.
167
- """
168
- # error model
169
- EPS = observations["eps"]
170
- S = jnp.clip(simulation_results["survival"], EPS, 1 - EPS)
171
-
172
- S_mask = masks["survival"]
173
- n_surv = observations["survivors_before_t"]
174
- obs_survival = observations["survival"]
175
-
176
- S_ = ffill_na(S, S_mask)
177
- S_cond_na = S_[:, 1:] / S_[:, :-1]
178
- S_cond_na_ = jnp.column_stack([jnp.ones_like(S[:,[0]]), S_cond_na])
179
- S_cond_na__ = jnp.clip(S_cond_na_, EPS, 1 - EPS)
180
-
181
- if make_predictions:
182
- obs_survival = None
183
-
184
- ID, T = S_mask.shape
185
-
186
- with numpyro.plate("time", size=T):
187
- with numpyro.plate("id", size=ID):
188
- numpyro.sample(
189
- "survival_obs", numpyro.distributions.Binomial(
190
- probs=S_cond_na__,
191
- total_count=n_surv
192
- ).mask(S_mask),
193
- obs=obs_survival
194
- )
195
-
196
-
197
- def conditional_lethality_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
198
- raise NotImplementedError(
199
- "Conditional lethality error model needs to forward fill survival " +
200
- "probabilities for missing (observation) values"
201
- ""
202
- )
203
-
204
- # error model
205
- EPS = observations["eps"]
206
- S = jnp.clip(simulation_results["S"], EPS, 1 - EPS)
207
-
208
- # TODO work on this
209
- # Conditional survival may be wrong, but also maybe the data generation
210
- # fct may be wrong
211
- # a test showed that it looks preddy darn good. I guess this is because
212
- # it is the conditional probability to survive. Not the conditional
213
- # probability to die.
214
- S_cond = (S[:, :-1] - S[:, 1:]) / S[:, :-1]
215
- S_cond_ = jnp.column_stack([jnp.zeros_like(S[:,[0]]), S_cond])
216
- S_cond__ = jnp.clip(S_cond_, EPS, 1 - EPS)
217
-
218
- n_surv = observations["survivors_before_t"]
219
- L_mask = masks["L"]
220
- L = observations["L"]
221
-
222
- if make_predictions:
223
- L = None
224
-
225
- numpyro.sample(
226
- "L_obs", numpyro.distributions.Binomial(
227
- probs=S_cond__,
228
- total_count=n_surv
229
- ).mask(L_mask),
230
- obs=L
231
- )
232
-
233
-
234
- def multinomial_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
235
- raise NotImplementedError(
236
- "Multinomial error model needs to forward fill survival " +
237
- "probabilities for missing (observation) values"
238
- ""
239
- )
240
- # error model
241
- EPS = observations["eps"]
242
- # TODO This should not be done before calculating the difference
243
- S = jnp.clip(simulation_results["S"], EPS, 1 - EPS)
244
-
245
- # TODO work on this
246
- # This is not working with replicates
247
- s_probs = S[:, :-1] - S[:, 1:]
248
- s_probs_ = jnp.column_stack([jnp.zeros_like(S[:,[0]]), s_probs])
249
-
250
- L_mask = masks["L"]
251
- L = observations["L"]
252
- N = observations["n_subjects"]
253
- # N_ = jnp.broadcast_to(jnp.expand_dims(N, 1), L.shape)
254
-
255
- if make_predictions:
256
- L = None
257
-
258
- ID, T = L_mask.shape
259
-
260
- with numpyro.plate("id", ID):
261
- numpyro.sample(
262
- "L_obs", numpyro.distributions.Multinomial(
263
- probs=s_probs_,
264
- total_count=N
265
- ).mask(L_mask),
266
- obs=L
267
- )
268
-
269
9
  def survival_predictions(
270
10
  probs, n_trials,
271
11
  eps: float = 0.0,
@@ -351,6 +91,8 @@ def posterior_predictions(sim, idata, seed=None):
351
91
  obs, masks = sim.inferer.observation_parser()
352
92
 
353
93
  model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
94
+ if sim.config.inference_numpyro.user_defined_error_model is None:
95
+ model_kwargs["obs"]["survival"] = None
354
96
 
355
97
  # prepare model
356
98
  model = partial(
@@ -371,24 +113,30 @@ def posterior_predictions(sim, idata, seed=None):
371
113
 
372
114
  samples = predictive(key)
373
115
 
374
- chains = []
375
- for i in range(c):
376
- # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!!!!!!!
377
- predictions = list(map(
378
- partial(
379
- survival_predictions,
380
- n_trials=obs["survival"][:, 0].astype(int),
381
- eps=obs["eps"],
382
- seed=seed,
383
- mode="survival",
384
- ),
385
- samples["survival"][i]
386
- ))
387
- chains.append(predictions)
116
+ if sim.config.inference_numpyro.user_defined_error_model is not None:
117
+ chains = []
118
+ for i in range(c):
119
+ # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!!!!!!!
120
+ predictions = list(map(
121
+ partial(
122
+ survival_predictions,
123
+ n_trials=obs["survival"][:, 0].astype(int),
124
+ eps=obs["eps"],
125
+ seed=seed,
126
+ mode="survival",
127
+ ),
128
+ samples["survival"][i]
129
+ ))
130
+ chains.append(predictions)
131
+
132
+ posterior_predictive = {"survival_obs": np.array(chains)}
133
+ else:
134
+ posterior_predictive = {"survival_obs": samples.pop("survival_obs")}
135
+
388
136
 
389
137
  new_idata = sim.inferer.to_arviz_idata(
390
138
  posterior=samples,
391
- posterior_predictive={"survival_obs": np.array(chains)},
139
+ posterior_predictive=posterior_predictive,
392
140
  n_draws=n,
393
141
  n_chains=c
394
142
  )
guts_base/sim/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from . import base
2
2
  from . import ecx
3
3
  from . import report
4
+ from . import utils
4
5
 
5
6
  from .base import (
6
7
  GutsBase,
@@ -11,4 +12,12 @@ from .base import (
11
12
  from .ecx import ECxEstimator, LPxEstimator
12
13
  from .report import GutsReport
13
14
 
14
- from .mempy import PymobSimulator
15
+ from .mempy import PymobSimulator
16
+ from .utils import (
17
+ GutsBaseError
18
+ )
19
+
20
+ from .constructors import (
21
+ construct_sim_from_config,
22
+ load_idata
23
+ )