guts-base 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

guts_base/prob.py ADDED
@@ -0,0 +1,412 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import numpyro
6
+ from numpyro.infer import Predictive
7
+ from typing import Literal
8
+
9
+
10
# Elementwise maximum lifted to a jnp ufunc, so that ``maximum.accumulate``
# yields a running (cumulative) maximum along an axis.
maximum = jnp.frompyfunc(  # type: ignore
    jnp.maximum,
    nin=2,
    nout=1,
    identity=None,
)
11
+
12
+
13
@jax.jit
def ffill_na(x, mask=None):
    """Forward-fill unobserved values of ``x`` along axis 1 (time).

    Every entry whose ``mask`` value is False is replaced by the most recent
    entry to its left in the same row whose ``mask`` value is True. If no
    mask is given, the non-NaN entries of ``x`` count as observed, i.e. NaN
    values are forward-filled.

    Parameters
    ----------
    x : ArrayLike
        2D array, dims=(batch, time).
    mask : ArrayLike of bool, optional
        Array of the same shape as ``x``; True where the value of ``x`` is
        kept. Defaults to ``~isnan(x)``.

    Returns
    -------
    jax.Array
        Array of the same shape as ``x``. Unobserved entries that precede the
        first observed entry of a row take the value of that row's first
        column, ``x[:, 0]``.
    """
    if mask is None:
        mask = jnp.logical_not(jnp.isnan(x))
    # Column index of every observed entry, 0 for unobserved entries. The
    # running maximum along time then points each entry at the latest
    # observed column at or before it.
    positions = jnp.where(mask, jnp.arange(mask.shape[1]), 0)
    source = jax.lax.cummax(positions, axis=1)
    return jnp.take_along_axis(x, source, axis=1)
34
+
35
+
36
def conditional_survival_from_hazard(x, mask):
    """Calculate conditional survival probabilities from cumulative hazards.

    Survival is assumed to be observed repeatedly over time, with some
    observations possibly missing. For every observed index (i, t) the
    probability of surviving from the most recent earlier observed time
    point of row i until t is ``exp(H(t_prev) - H(t))``, with
    ``H = -ln(S)``. An implicit, always observed time point with ``H = 0``
    (``S = 1``) is placed before the first column. The first column
    therefore yields ``exp(-H(t_0)) = S(t_0)``.

    Parameters
    ----------
    x : np.ndarray[I, T, float]
        Cumulative hazards ``H = -ln(S)``. I is the batch dimension and T is
        the time dimension.
    mask : np.ndarray[I, T, bool]
        Same shape as ``x``. True if survival was observed at index (i, t),
        False if it was not.

    Returns
    -------
    out : jax.Array[I, T, float]
        Conditional survival probabilities, with NaN wherever ``mask`` is
        False. Same shape as ``x``.

    Example
    -------
    Survival probabilities converted to conditional survival probabilities
    with some masked values (output rounded):

    >>> S_i = np.array([
    ...     [1.0, 0.75, 0.1, 0.05, 0.01, 0.0],
    ...     [1.0, 0.75, 0.1, 0.05, 0.01, 0.0],
    ...     [1.0, 0.75, 0.1, 0.05, 0.01, 0.0],
    ...     [1.0, 0.75, 0.1, 0.05, 0.01, 0.0],
    ...     [1.0, 0.75, 0.1, 0.05, 0.01, 0.0],
    ... ])
    >>> mask_obs = np.array([
    ...     [True, True, True, True, True, True],
    ...     [True, False, True, True, True, True],
    ...     [True, False, False, True, True, True],
    ...     [True, True, False, True, True, True],
    ...     [True, False, True, False, True, True],
    ... ])
    >>> conditional_survival_from_hazard(-jnp.log(S_i), mask_obs)  # doctest: +SKIP
    [[1.0, 0.75, 0.133, 0.5,    0.2, 0.0],
     [1.0, nan,  0.1,   0.5,    0.2, 0.0],
     [1.0, nan,  nan,   0.05,   0.2, 0.0],
     [1.0, 0.75, nan,   0.0667, 0.2, 0.0],
     [1.0, nan,  0.1,   nan,    0.1, 0.0]]
    """
    # Leading column with zero hazard: before the zeroth observation it is
    # safe to assume S(t=-1) = 1. That column always counts as observed.
    hazard = jnp.concatenate([jnp.zeros_like(x[:, :1]), x], axis=1)
    observed = jnp.concatenate([jnp.ones_like(mask[:, :1]), mask], axis=1)

    # Carry the last observed hazard forward over the unobserved entries.
    carried = ffill_na(hazard, observed)

    # Hazard at the previous observation minus the hazard at t (unfilled).
    log_ratio = carried[:, :-1] - hazard[:, 1:]

    return jnp.where(mask, jnp.exp(log_ratio), jnp.nan)
105
+
106
+
107
def conditional_survival_hazard_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
    r"""Likelihood of observing K survivors at t given N survivors at the previous observation.

    The conditional survival probability is computed directly from the
    cumulative hazard (via ``conditional_survival_from_hazard``). The same
    equation could easily be adapted to the conditional lethality
    probability.

    $$\Pr(t < T~|~t_0 < T) = \frac{e^{-H(t)}}{e^{-H(t_0)}} = e^{-H(t) - (- H(t_0))} = e^{-H(t) + H(t_0)} = e^{-(H(t) - H(t_0))} = e^{H(t_0) - H(t)}$$

    Example
    -------
    ::

                              t0     t1     t2     t3     tinf
        observations          10     8      5      2      0
        Pr(T > t)             1.0    0.8    0.5    0.2    0.0
        Pr(T > t | T > t-1)   (1.0)  0.8    0.625  0.4    0.0

    If the experiment ends before the last organism has died, this
    information has to be included in the computation of the likelihood.
    Multinomial likelihoods do this with an extra bin for the (unobserved)
    deaths after the end of the experiment. The count in that bin equals the
    number of survivors at the end of the experiment. Conditional survival
    needs no such bin: the number of organisms alive at the end of the
    experiment is already contained in the last observation.

    An additional interval before the first time point (with ``H = 0``) is
    included to align the results with the observations, which include
    t = 0. Otherwise only T - 1 intervals would be computed for T
    observations (including t = 0).

    Parameters
    ----------
    theta, indices
        Not used by this error model.
    simulation_results : dict
        Provides the cumulative hazard under key ``"H"``.
    observations : dict
        Provides ``"survivors_before_t"`` (binomial trials) and ``"survival"``
        (observed survivor counts).
    masks : dict
        Provides ``"survival"``, a 2D boolean mask (id, time) of observed
        entries.
    make_predictions : bool
        If true, the ``"survival_obs"`` site is left unobserved so it is
        sampled.
    """
    cumulative_hazard = simulation_results["H"]
    survivors_before = observations["survivors_before_t"]
    survival_mask = masks["survival"]
    observed_survivors = observations["survival"]
    probs = conditional_survival_from_hazard(cumulative_hazard, survival_mask)

    if make_predictions:
        # no observed value -> numpyro draws the site instead of scoring it
        observed_survivors = None

    n_ids, n_times = survival_mask.shape

    likelihood = numpyro.distributions.Binomial(
        probs=probs,
        total_count=survivors_before,
    ).mask(survival_mask)

    with numpyro.plate("time", size=n_times), numpyro.plate("id", size=n_ids):
        numpyro.sample("survival_obs", likelihood, obs=observed_survivors)
159
+
160
+
161
def conditional_survival_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
    """Binomial error model on survivors, built from simulated survival probabilities.

    Note that the conditional survival error model will not work to generate
    posterior predictions for observational uncertainty out of the box.

    This model can be used for SD and IT models, but for SD models
    ``conditional_survival_hazard_error_model`` is preferable.

    The simulated survival (``simulation_results["survival"]``) is clipped to
    ``[eps, 1 - eps]`` with ``eps = observations["eps"]``. It is then
    forward-filled over unobserved time points (``masks["survival"]`` is
    False) and divided by its value one column earlier. The result is the
    probability of surviving from the most recent earlier observed time point
    to the current one. The first column gets probability 1. All
    probabilities are clipped to ``[eps, 1 - eps]`` again.
    ``observations["survival"]`` is modelled as Binomial with
    ``observations["survivors_before_t"]`` trials, masked by
    ``masks["survival"]``. If ``make_predictions`` is true, the site is left
    unobserved. ``theta`` and ``indices`` are not used.
    """
    eps = observations["eps"]
    survival = jnp.clip(simulation_results["survival"], eps, 1 - eps)

    survival_mask = masks["survival"]
    survivors_before = observations["survivors_before_t"]
    observed_survivors = observations["survival"]

    # divide by the survival at the previous *observed* time point
    carried = ffill_na(survival, survival_mask)
    interval_probs = jnp.concatenate(
        [jnp.ones_like(survival[:, [0]]), carried[:, 1:] / carried[:, :-1]],
        axis=1,
    )
    interval_probs = jnp.clip(interval_probs, eps, 1 - eps)

    if make_predictions:
        observed_survivors = None

    n_ids, n_times = survival_mask.shape

    likelihood = numpyro.distributions.Binomial(
        probs=interval_probs,
        total_count=survivors_before,
    ).mask(survival_mask)

    with numpyro.plate("time", size=n_times), numpyro.plate("id", size=n_ids):
        numpyro.sample("survival_obs", likelihood, obs=observed_survivors)
195
+
196
+
197
def conditional_lethality_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
    """Placeholder for a conditional-lethality (binomial death count) error model.

    Not implemented yet. Like ``conditional_survival_error_model``, it would
    first need to forward-fill the simulated survival probabilities over
    unobserved time points.

    NOTE(review): an earlier, never-executed draft of this function (see
    version history) clipped ``simulation_results["S"]`` to
    ``[eps, 1 - eps]``. It used the per-interval death probability
    ``(S[t-1] - S[t]) / S[t-1]`` (0 for the first column) in a Binomial
    likelihood for ``observations["L"]``, masked by ``masks["L"]``, with
    ``observations["survivors_before_t"]`` trials.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "Conditional lethality error model needs to forward fill survival "
        "probabilities for missing (observation) values"
    )
232
+
233
+
234
def multinomial_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
    """Placeholder for a multinomial error model on lethality counts.

    Not implemented yet. It would first need to forward-fill the simulated
    survival probabilities over unobserved time points.

    NOTE(review): an earlier, never-executed draft of this function (see
    version history) used the interval death probabilities
    ``S[t-1] - S[t]`` (0 for the first column) of the clipped
    ``simulation_results["S"]``. It fed them to a Multinomial likelihood for
    ``observations["L"]``, masked by ``masks["L"]``, with
    ``observations["n_subjects"]`` trials. Its own TODOs noted that clipping
    happened before differencing and that it did not work with replicates.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "Multinomial error model needs to forward fill survival "
        "probabilities for missing (observation) values"
    )
268
+
269
def survival_predictions(
    probs, n_trials,
    eps: float = 0.0,
    seed=1,
    mode: Literal["survival", "lethality", "deaths"] = "survival"
):
    """Generate predictions for survival based on a multinomial survival distribution.

    Each row of ``probs`` is treated as the survival curve S(t) of one
    experiment. The curve is converted into the probability of dying in each
    observation interval: 0 for the first time point, ``S(t-1) - S(t)``
    afterwards. One final bin holds the remaining probability mass, i.e.
    deaths after the last time point. The organisms of each experiment are
    distributed over these bins with one multinomial draw, and the final bin
    is discarded afterwards.

    Parameters
    ----------
    probs : ArrayLike
        2D array of survival probabilities for each time point per
        experiment. dims=(experiment, time)
    n_trials : ArrayLike
        1D array with the number of organisms at the beginning of each
        experiment. dims=(experiment,)
    eps : float, optional
        Numerical tolerance, by default 0.0. If positive, survival
        probabilities are truncated to multiples of ``eps`` before
        differencing, so that numeric errors below the solver tolerance
        cannot produce negative death probabilities. Death probabilities are
        then clipped to ``[eps, 1 - eps]``. With ``eps=0`` no truncation
        takes place and death probabilities are only clipped to ``[0, 1]``.
    seed : int or numpy.random.Generator, optional
        Passed to ``numpy.random.default_rng``, by default 1. A ``Generator``
        is used as is, so successive calls can share one random stream.
    mode : str, optional
        How should the random draws be returned?
        - survival: Decreasing from n_trials to 0
        - lethality: Increasing from 0 to n_trials
        - deaths: Between 0 and n_trials in each interval. Summing to n_trials

    Returns
    -------
    jax.Array
        Integer counts with dims=(experiment, time).

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the options above.
    """
    # Validate before drawing anything.
    if mode not in ("survival", "lethality", "deaths"):
        raise NotImplementedError(
            f"Mode {mode} is not implemented. "
            "Use one of 'survival', 'lethality', or 'deaths'."
        )

    def survival_to_death_probs(pr_survival):
        # float64 keeps the normalized vector within numpy's multinomial
        # sum-to-one check (float32 rounding could exceed it)
        pr_survival = np.asarray(pr_survival, dtype=np.float64)
        if eps > 0:
            # Truncate, because numeric errors below the solver tolerance can
            # lead to negative values in the difference. The guard is needed
            # because eps == 0 would turn every probability into NaN.
            pr_survival = np.trunc(pr_survival / eps) * eps
        pr_interval = pr_survival[:-1] - pr_survival[1:]

        pr_death = np.concatenate([
            # no deaths at T = 0
            np.zeros(1),
            # Delta S
            pr_interval,
            # the remaining mortality as T -> infinity
            np.array([1.0 - pr_interval.sum()]),
        ])

        # Keep the entries inside [eps, 1 - eps] (0 and 1 are problematic
        # for probabilities when eps > 0) and make the vector sum to 1.
        pr_death = np.clip(pr_death, eps, 1 - eps)
        return pr_death / pr_death.sum()

    rng = np.random.default_rng(seed)
    deaths = jnp.array([
        rng.multinomial(n=n, pvals=survival_to_death_probs(pr_survival))
        for n, pr_survival in zip(n_trials, probs)
    ])

    # remove the last column to trim off the simulated unobserved mortality
    deaths = deaths[:, :-1]

    if mode == "deaths":
        return deaths
    if mode == "lethality":
        return deaths.cumsum(axis=1)
    return np.expand_dims(n_trials, axis=1) - deaths.cumsum(axis=1)
340
+
341
def posterior_predictions(sim, idata, seed=None):
    """Make posterior predictions for survival data.

    Re-runs ``sim.inferer.inference_model`` for every draw in
    ``idata.unconstrained_posterior`` with numpyro's ``Predictive``. For each
    draw, survival counts are sampled with ``survival_predictions`` from the
    predicted survival curves. The function then checks that the recomputed
    posterior and model fits agree with those stored in ``idata``.

    Parameters
    ----------
    sim
        Simulation object providing ``config``, ``inferer`` and
        ``coordinates`` (as used below). NOTE(review): project type;
        attributes inferred from usage here.
    idata
        Presumably an arviz ``InferenceData`` with ``posterior``,
        ``unconstrained_posterior`` and ``posterior_model_fits`` groups.
    seed : int, optional
        Seed for the JAX PRNG key and for the NumPy generator that draws the
        observation noise. Defaults to ``sim.config.simulation.seed``.

    Returns
    -------
    idata
        The input object, updated in place with new
        ``posterior_model_fits`` and ``posterior_predictive`` groups.

    Raises
    ------
    AssertionError
        If the mean absolute difference between the recomputed and the
        stored posterior, or model fits, exceeds the tolerance derived from
        ``sim.config.jaxsolver.atol``.
    """
    if seed is None:
        seed = sim.config.simulation.seed

    # Dataset.sizes maps dimension name -> length; indexing Dataset.dims
    # this way is deprecated in xarray.
    n_draws = idata.posterior.sizes["draw"]
    n_chains = idata.posterior.sizes["chain"]

    key = jax.random.PRNGKey(seed)

    obs, masks = sim.inferer.observation_parser()

    model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)

    # prepare model
    model = partial(
        sim.inferer.inference_model,
        solver=sim.inferer.evaluator,
        **model_kwargs
    )

    posterior_samples = {
        k: np.array(v["data"])
        for k, v in idata.unconstrained_posterior.to_dict()["data_vars"].items()
    }

    predictive = Predictive(
        model, posterior_samples=posterior_samples,
        num_samples=n_draws, batch_ndims=2
    )

    samples = predictive(key)

    # One generator shared by all draws. Passing the integer seed to every
    # call would restart the same random stream for each draw and chain,
    # freezing the observation noise. numpy.random.default_rng (used inside
    # survival_predictions) returns a Generator argument unaltered.
    rng = np.random.default_rng(seed)
    predict_draw = partial(
        survival_predictions,
        n_trials=obs["survival"][:, 0].astype(int),
        eps=obs["eps"],
        seed=rng,
        mode="survival",
    )

    # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!!!!!!!
    chains = [
        [predict_draw(draw_probs) for draw_probs in samples["survival"][chain]]
        for chain in range(n_chains)
    ]

    new_idata = sim.inferer.to_arviz_idata(
        posterior=samples,
        posterior_predictive={"survival_obs": np.array(chains)},
        n_draws=n_draws,
        n_chains=n_chains
    )

    # update chain names in case they were subselected (clustering)
    new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})

    # assert the new posterior matches the old posterior
    tol = sim.config.jaxsolver.atol * 100
    abs_diff_posterior = np.abs(idata.posterior - new_idata.posterior)
    np.testing.assert_array_less(abs_diff_posterior.mean().to_array(), tol)

    fit_tol = tol * sim.coordinates["time"].max()
    abs_diff_fits = np.abs(new_idata.posterior_model_fits - idata.posterior_model_fits)
    np.testing.assert_array_less(abs_diff_fits.mean().to_array(), fit_tol)

    idata.posterior_model_fits = new_idata.posterior_model_fits
    idata.posterior_predictive = new_idata.posterior_predictive

    return idata
@@ -0,0 +1,14 @@
1
+ from . import base
2
+ from . import ecx
3
+ from . import report
4
+
5
+ from .base import (
6
+ GutsBase,
7
+ GutsSimulationConstantExposure,
8
+ GutsSimulationVariableExposure
9
+ )
10
+
11
+ from .ecx import ECxEstimator, LPxEstimator
12
+ from .report import GutsReport
13
+
14
+ from .mempy import PymobSimulator