guts-base 0.8.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
This version of guts-base has been flagged as potentially problematic.
- guts_base/__init__.py +2 -1
- guts_base/data/__init__.py +1 -1
- guts_base/data/generator.py +2 -1
- guts_base/data/survival.py +6 -0
- guts_base/mod.py +24 -83
- guts_base/prob.py +23 -275
- guts_base/sim/__init__.py +10 -1
- guts_base/sim/base.py +285 -75
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +168 -58
- guts_base/sim/mempy.py +85 -70
- guts_base/sim/report.py +0 -1
- guts_base/sim/utils.py +10 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/METADATA +2 -3
- guts_base-1.0.0.dist-info/RECORD +25 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.6.dist-info/RECORD +0 -24
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/WHEEL +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/entry_points.txt +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {guts_base-0.8.6.dist-info → guts_base-1.0.0.dist-info}/top_level.txt +0 -0
guts_base/__init__.py
CHANGED
guts_base/data/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from .survival import (
     generate_survival_repeated_observations
 )

-from .generator import create_artificial_data, design_exposure_scenario
+from .generator import create_artificial_data, design_exposure_scenario, ExposureDataDict

 from .expydb import (
     to_dataset,

guts_base/data/generator.py
CHANGED
@@ -11,7 +11,8 @@ def create_artificial_data(
     t_max,
     dt,
     exposure_paths=["oral", "topical", "contact"],
-    intensity=[0.1, 0.5, 0.05]
+    intensity=[0.1, 0.5, 0.05],
+    seed=1,
 ):
     rng = np.random.default_rng(1)
     time = np.arange(0, t_max, step=dt)  # daily time resolution

guts_base/data/survival.py
CHANGED
@@ -40,6 +40,12 @@ def prepare_survival_data_for_conditional_binomial(observations: xr.Dataset) ->
             nsurv.isel(time=list(range(0, len(nsurv.time)-1))).values
         ]).astype(int))})

+    observations = observations.assign_coords({
+        "survivors_at_start": (("id", "time"), np.broadcast_to(
+            nsurv.isel(time=0).values.reshape(-1,1),
+            shape=nsurv.shape
+        ).astype(int))})
+
     return observations


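The added `survivors_at_start` coordinate simply repeats each replicate's survivor count at the first observation time along the whole time axis. A minimal, self-contained sketch of that broadcast with made-up numbers (not code from the package):

import numpy as np

# stand-in for the (id x time) survivor array used in
# prepare_survival_data_for_conditional_binomial
nsurv = np.array([
    [10,  8,  5,  2],   # replicate 1
    [20, 19, 15,  9],   # replicate 2
])

survivors_at_start = np.broadcast_to(
    nsurv[:, 0].reshape(-1, 1),   # first observation per replicate, as a column
    shape=nsurv.shape,
).astype(int)

print(survivors_at_start)
# [[10 10 10 10]
#  [20 20 20 20]]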
guts_base/mod.py
CHANGED
@@ -1,10 +1,3 @@
-"""
-GUTS models
-
-TODO: Import guts models from mempy and update to work well with jax
-TODO: Based on this implement the bufferguts model in mempy and import in the
-bufferguts case-study
-"""
 from functools import partial
 from pathlib import Path

@@ -19,79 +12,37 @@ from pymob.solvers.symbolic import (
     PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
 )

-
-
-
-
-
-
-
-
-
-
-)
+class RED_SD:
+    """Simplest guts model, mainly for testing"""
+    @staticmethod
+    def _rhs_jax(t, y, x_in, kd, b, m, hb):
+        D, H = y
+        dD_dt = kd * (x_in.evaluate(t) - D)
+        dH_dt = b * jnp.maximum(0.0, D - m) + hb
+        return dD_dt, dH_dt
+
+    @staticmethod
+    def _solver_post_processing(results, t, interpolation):
+        results["survival"] = jnp.exp(-results["H"])
+        results["exposure"] = jax.vmap(interpolation.evaluate)(t)
+        return results

 red_sd = RED_SD._rhs_jax
 red_sd_post_processing = RED_SD._solver_post_processing

-
-
+class RED_SD_DA(RED_SD):
+    @staticmethod
+    def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
+        D, H = y
+        dD_dt = kd * (x_in.evaluate(t) - D)
+        dH_dt = b * jnp.maximum(0.0, jnp.sum(w * D) - m) + hb
+        return dD_dt, dH_dt

-
-
+red_sd_da = RED_SD_DA._rhs_jax
+red_sd_da_post_processing = RED_SD_DA._solver_post_processing


-def
-    """Computes the stochastic death survival probability after computing
-
-    """
-    # calculate survival
-    results["exposure"] = jax.vmap(interpolation.evaluate)(t)
-    p_surv = survival_jax(t, results["D"], z, k_k, h_b)
-    results["survival"] = p_surv
-    results["lethality"] = 1 - p_surv
-    return results
-
-def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
-    results["exposure"] = jax.vmap(interpolation.evaluate)(t)
-    p_surv = survival_IT_jax(t, results["D"], alpha, beta, h_b, eps)
-    results["survival"] = p_surv
-    results["H"] = - jnp.log(p_surv)
-    return results
-
-def post_exposure(results, t, interpolation):
-    results["survival"] = jnp.exp(-results["H"])
-    results["exposure"] = jax.vmap(interpolation.evaluate)(t)
-    return results
-
-
-def no_post_processing(results):
-    return results
-
-
-@jax.jit
-def survival_jax(t, damage, z, kk, h_b):
-    """
-    survival probability derived from hazard
-    first calculate cumulative Hazard by integrating hazard cumulatively over t
-    then calculate the resulting survival probability
-    It was checked that `survival_jax` behaves exactly the same as `survival`
-    """
-    hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
-    H = jnp.array([jax.scipy.integrate.trapezoid(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
-    # H = jnp.array([jnp.trapz(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
-    S = jnp.exp(-H)
-
-    return S
-
-@jax.jit
-def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
-    d_max = jnp.squeeze(jnp.array([jnp.max(damage[:i+1])+eps for i in range(len(t))]))
-    F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
-    S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
-    return S
-
-def guts_jax(t, y, C_0, k_d, z, b, h_b):
+def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
     # for constant exposure
     D, H, S = y
     dD_dt = k_d * (C_0 - D)
@@ -103,16 +54,6 @@ def guts_jax(t, y, C_0, k_d, z, b, h_b):

     return dD_dt, dH_dt, dS_dt

-
-def RED_IT(t, y, x_in, kd):
-    D, = y
-    C = x_in.evaluate(t)
-
-    dD_dt = kd * (C - D)
-
-    return (dD_dt, )
-
-
 def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
     # for constant exposure
     D, H, S = y

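The new `RED_SD` and `RED_SD_DA` classes bundle the JAX right-hand side with its post-processing. A minimal sketch (assuming the package is installed) of evaluating the `RED_SD` right-hand side once; the `ConstantExposure` stub is hypothetical and only provides the `evaluate(t)` method that the RHS expects from `x_in`:

import jax.numpy as jnp
from guts_base.mod import RED_SD

class ConstantExposure:
    # hypothetical stand-in for the exposure interpolation object
    def __init__(self, c):
        self.c = c
    def evaluate(self, t):
        return self.c

dD_dt, dH_dt = RED_SD._rhs_jax(
    t=0.0,
    y=(jnp.array(0.0), jnp.array(0.0)),   # damage D and cumulative hazard H
    x_in=ConstantExposure(1.0),
    kd=0.5, b=0.2, m=1.0, hb=0.01,
)
# dD_dt = kd * (1.0 - 0.0) = 0.5; dH_dt = hb = 0.01 while D is below the threshold m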
guts_base/prob.py
CHANGED
@@ -2,270 +2,10 @@ from functools import partial
 import numpy as np
 import jax
 import jax.numpy as jnp
-import numpyro
 from numpyro.infer import Predictive
 from typing import Literal


-maximum = jnp.frompyfunc(jnp.maximum, nin=2, nout=1, identity=None) # type: ignore
-
-
-@jax.jit
-def ffill_na(x, mask):
-    """Forward-fill nan values in x. If a mask is provided, assume
-
-    Parameters
-    ----------
-    x : _type_
-        _description_
-    mask : _type_
-        _description_
-
-    Returns
-    -------
-    _type_
-        _description_
-    """
-    if mask is None:
-        mask = jnp.logical_not(jnp.isnan(x))
-    idx = jnp.where(mask,jnp.arange(mask.shape[1]),0)
-    idx_ = maximum.accumulate(jnp.array(idx), axis=1).astype(int)
-    return x[jnp.arange(idx.shape[0])[:,None], idx_]
-
-
-def conditional_survival_from_hazard(x, mask):
-    """Calculates the conditional survival from cumulative hazard values.
-    This equation is used when survival is repeatedly observed over time and
-
-    Parameters
-    ----------
-    x : np.ndarray[I,T, float]
-        A 2-dimensional I x T array of cumulative hazards defined as H = -ln(S).
-        I is the batch dimension and T is the time dimension
-
-    mask : np.ndarray[I,T, bool]
-        A 2-dimensional array of the same shape as x, taking True if the survival
-        was observed for the given index (i,t) and taking False if survival was
-        not observed for the given index (i,t).
-
-    Returns
-    -------
-    out : np.ndarray[I,T, float]
-        A matrix with conditional probabilities and nans in place where the
-        mask has nans. Output has the same shape as input.
-
-    Example
-    -------
-    Calculation example from survival probabilities to conditional survival
-    probabilities given some masked values.
-
-    >>> S_i = np.array([
-    >>>     [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
-    >>>     [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
-    >>>     [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
-    >>>     [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
-    >>>     [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
-    >>> ])
-
-    >>> mask_obs = np.array([
-    >>>     [ True,  True,  True,  True,  True,  True],
-    >>>     [ True, False,  True,  True,  True,  True],
-    >>>     [ True, False, False,  True,  True,  True],
-    >>>     [ True,  True, False,  True,  True,  True],
-    >>>     [ True, False,  True, False,  True,  True],
-    >>> ])
-
-    >>> conditional_survival_from_hazard(-jnp.log(S_i), mask_obs)
-    array([
-        [1.0  0.75  0.133  0.5     0.2  0.0]
-        [1.0  nan   0.1    0.5     0.2  0.0]
-        [1.0  nan   nan    0.05    0.2  0.0]
-        [1.0  0.75  nan    0.0667  0.2  0.0]
-        [1.0  nan   0.1    nan     0.1  0.0]
-    ])
-    """
-
-    # Append zeros (hazard) to the beginning of the array (this aligns with the
-    # safe assumption that before the zeroth observation S(t=-1) = 1.0)
-    x_ = jnp.column_stack([jnp.zeros_like(x[:, 0]), x])
-
-    # also mask needs to be expanded accordingly
-    mask_ = jnp.column_stack([jnp.ones_like(mask[:, 0]), mask])
-
-    # fill NaNs with forward
-    x_filled = ffill_na(x_, mask_)
-
-    # calculate the conditional survival.
-    conditional_survival = jnp.exp(x_filled[:, :-1] - (x_[:, 1:]))
-
-    # add nans and return
-    return jnp.where(
-        mask, conditional_survival, jnp.nan
-    )
-
-
-def conditional_survival_hazard_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
-    """Computes the likelihood of observing K survivors at time t given that
-    N survivors were alive at the previous observation t-1
-
-    This is achieved by using the Cumulative hazard directly in the calculation
-    of the conditional survival probability. Note that this equation can easily
-    be adapted to compute the conditional lethality probability.
-
-    $$\\Pr(t < T~|~t_0 < T) = \\frac{e^{-H(t)}}{e^{-H(t_0)}} = e^{-H(t) - (- H(t_0))} = e^{-H(t) + H(t_0)} = e^{-(H(t) - H(t_0))} = e^{H(t_0) - H(t)}$$
-
-    Example
-    -------
-
-                         t0     t1     t2     t3   tinf
-    observations         10      8      5      2      0
-    Pr(T > t)            1.0    0.8    0.5    0.2    0.0
-    Pr(T > t | T > t-1)  (1.0)  0.8    0.625  0.4    0.0
-
-
-    If the experiment ends before the last organism has died, these information
-    have to be included in he computation of the likelihood. For multinomial
-    distributions this is done by computing the remaining (ubobserved) lethality
-    until the end of the experiment, whihc is just the inverse of the number of
-    survivors at the end of the experiment. Therefore, for conditional survival
-    the information about the number of alive organisms
-    at the end of the experiment is contained in the last observation.
-
-    We compute an additional interval until the start of the experiment in order
-    to align the results with the observations (which include t=0). Otherwise
-    we would only compute T-1 intervals if we have T observations (including t=0).
-    """
-
-    H = simulation_results["H"]
-    n_surv = observations["survivors_before_t"]
-    S_mask = masks["survival"]
-    obs_survival = observations["survival"]
-    S_conditional = conditional_survival_from_hazard(H, S_mask)
-
-    if make_predictions:
-        obs_survival = None
-
-    ID, T = S_mask.shape
-
-    with numpyro.plate("time", size=T):
-        with numpyro.plate("id", size=ID):
-            numpyro.sample(
-                "survival_obs", numpyro.distributions.Binomial(
-                    probs=S_conditional,
-                    total_count=n_surv
-                ).mask(S_mask),
-                obs=obs_survival
-            )
-
-
-def conditional_survival_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
-    """Note that the conditional survival error model will not work to generate
-    posterior predictions for observational uncertainty out of the box
-
-    This model can be used for SD and IT models, but for SD models, the
-    conditional_survival_hazard_model is preferrable.
-    """
-    # error model
-    EPS = observations["eps"]
-    S = jnp.clip(simulation_results["survival"], EPS, 1 - EPS)
-
-    S_mask = masks["survival"]
-    n_surv = observations["survivors_before_t"]
-    obs_survival = observations["survival"]
-
-    S_ = ffill_na(S, S_mask)
-    S_cond_na = S_[:, 1:] / S_[:, :-1]
-    S_cond_na_ = jnp.column_stack([jnp.ones_like(S[:,[0]]), S_cond_na])
-    S_cond_na__ = jnp.clip(S_cond_na_, EPS, 1 - EPS)
-
-    if make_predictions:
-        obs_survival = None
-
-    ID, T = S_mask.shape
-
-    with numpyro.plate("time", size=T):
-        with numpyro.plate("id", size=ID):
-            numpyro.sample(
-                "survival_obs", numpyro.distributions.Binomial(
-                    probs=S_cond_na__,
-                    total_count=n_surv
-                ).mask(S_mask),
-                obs=obs_survival
-            )
-
-
-def conditional_lethality_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
-    raise NotImplementedError(
-        "Conditional lethality error model needs to forward fill survival " +
-        "probabilities for missing (observation) values"
-        ""
-    )
-
-    # error model
-    EPS = observations["eps"]
-    S = jnp.clip(simulation_results["S"], EPS, 1 - EPS)
-
-    # TODO work on this
-    # Conditional survival may be wrong, but also maybe the data generation
-    # fct may be wrong
-    # a test showed that it looks preddy darn good. I guess this is because
-    # it is the conditional probability to survive. Not the conditional
-    # probability to die.
-    S_cond = (S[:, :-1] - S[:, 1:]) / S[:, :-1]
-    S_cond_ = jnp.column_stack([jnp.zeros_like(S[:,[0]]), S_cond])
-    S_cond__ = jnp.clip(S_cond_, EPS, 1 - EPS)
-
-    n_surv = observations["survivors_before_t"]
-    L_mask = masks["L"]
-    L = observations["L"]
-
-    if make_predictions:
-        L = None
-
-    numpyro.sample(
-        "L_obs", numpyro.distributions.Binomial(
-            probs=S_cond__,
-            total_count=n_surv
-        ).mask(L_mask),
-        obs=L
-    )
-
-
-def multinomial_error_model(theta, simulation_results, observations, indices, masks, make_predictions):
-    raise NotImplementedError(
-        "Multinomial error model needs to forward fill survival " +
-        "probabilities for missing (observation) values"
-        ""
-    )
-    # error model
-    EPS = observations["eps"]
-    # TODO This should not be done before calculating the difference
-    S = jnp.clip(simulation_results["S"], EPS, 1 - EPS)
-
-    # TODO work on this
-    # This is not working with replicates
-    s_probs = S[:, :-1] - S[:, 1:]
-    s_probs_ = jnp.column_stack([jnp.zeros_like(S[:,[0]]), s_probs])
-
-    L_mask = masks["L"]
-    L = observations["L"]
-    N = observations["n_subjects"]
-    # N_ = jnp.broadcast_to(jnp.expand_dims(N, 1), L.shape)
-
-    if make_predictions:
-        L = None
-
-    ID, T = L_mask.shape
-
-    with numpyro.plate("id", ID):
-        numpyro.sample(
-            "L_obs", numpyro.distributions.Multinomial(
-                probs=s_probs_,
-                total_count=N
-            ).mask(L_mask),
-            obs=L
-        )
-
 def survival_predictions(
     probs, n_trials,
     eps: float = 0.0,
@@ -351,6 +91,8 @@ def posterior_predictions(sim, idata, seed=None):
     obs, masks = sim.inferer.observation_parser()

     model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
+    if sim.config.inference_numpyro.user_defined_error_model is None:
+        model_kwargs["obs"]["survival"] = None

     # prepare model
     model = partial(
@@ -371,24 +113,30 @@ def posterior_predictions(sim, idata, seed=None):

     samples = predictive(key)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if sim.config.inference_numpyro.user_defined_error_model is not None:
+        chains = []
+        for i in range(c):
+            # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!!!!!!!
+            predictions = list(map(
+                partial(
+                    survival_predictions,
+                    n_trials=obs["survival"][:, 0].astype(int),
+                    eps=obs["eps"],
+                    seed=seed,
+                    mode="survival",
+                ),
+                samples["survival"][i]
+            ))
+            chains.append(predictions)
+
+        posterior_predictive = {"survival_obs": np.array(chains)}
+    else:
+        posterior_predictive = {"survival_obs": samples.pop("survival_obs")}
+

     new_idata = sim.inferer.to_arviz_idata(
         posterior=samples,
-        posterior_predictive=
+        posterior_predictive=posterior_predictive,
         n_draws=n,
         n_chains=c
     )

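The docstring of the removed conditional_survival_hazard_error_model states the identity Pr(T > t | T > t_prev) = e^{H(t_prev) - H(t)} = S(t) / S(t_prev). A short NumPy check (not code from the package) reproducing its example table of 10, 8, 5, 2, 0 survivors:

import numpy as np

S = np.array([1.0, 0.8, 0.5, 0.2, 1e-12])   # survival probabilities; 0 replaced by a tiny value
H = -np.log(S)                              # cumulative hazard

S_cond = np.exp(H[:-1] - H[1:])             # e^{H(t_prev) - H(t)}
print(np.round(S_cond, 3))                  # [0.8   0.625 0.4   0.   ]
print(np.round(S[1:] / S[:-1], 3))          # identical: S(t) / S(t_prev)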
guts_base/sim/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from . import base
 from . import ecx
 from . import report
+from . import utils

 from .base import (
     GutsBase,
@@ -11,4 +12,12 @@ from .base import (
 from .ecx import ECxEstimator, LPxEstimator
 from .report import GutsReport

-from .mempy import PymobSimulator
+from .mempy import PymobSimulator
+from .utils import (
+    GutsBaseError
+)
+
+from .constructors import (
+    construct_sim_from_config,
+    load_idata
+)