guts-base 0.8.6__tar.gz → 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- {guts_base-0.8.6 → guts_base-1.0.1}/PKG-INFO +2 -3
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/__init__.py +2 -1
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/__init__.py +1 -1
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/generator.py +2 -1
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/survival.py +6 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/mod.py +24 -83
- guts_base-1.0.1/guts_base/prob.py +160 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/sim/__init__.py +10 -1
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/sim/base.py +285 -75
- guts_base-1.0.1/guts_base/sim/constructors.py +31 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/sim/ecx.py +174 -59
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/sim/mempy.py +85 -70
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/sim/report.py +0 -1
- guts_base-1.0.1/guts_base/sim/utils.py +10 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base.egg-info/PKG-INFO +2 -3
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base.egg-info/SOURCES.txt +2 -2
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base.egg-info/requires.txt +1 -2
- {guts_base-0.8.6 → guts_base-1.0.1}/pyproject.toml +3 -4
- guts_base-1.0.1/tests/test_ecx.py +38 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/tests/test_from_pymob.py +1 -8
- {guts_base-0.8.6 → guts_base-1.0.1}/tests/test_simulations.py +3 -29
- guts_base-0.8.6/guts_base/prob.py +0 -412
- guts_base-0.8.6/guts_base/sim.py +0 -0
- guts_base-0.8.6/tests/test_ecx.py +0 -62
- guts_base-0.8.6/tests/test_simulations_from_mempy.py +0 -119
- {guts_base-0.8.6 → guts_base-1.0.1}/LICENSE +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/README.md +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/expydb.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/openguts.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/preprocessing.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/time_of_death.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/data/utils.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base/plot.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base.egg-info/dependency_links.txt +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base.egg-info/entry_points.txt +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/guts_base.egg-info/top_level.txt +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/setup.cfg +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/tests/test_data_import.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/tests/test_scripted_simulations.py +0 -0
- {guts_base-0.8.6 → guts_base-1.0.1}/tests/test_symbolic_solve.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: guts_base
|
|
3
|
-
Version: 0.8.6
|
|
3
|
+
Version: 1.0.1
|
|
4
4
|
Summary: Basic GUTS model implementation in pymob
|
|
5
5
|
Author-email: Florian Schunck <fluncki@protonmail.com>
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -692,8 +692,7 @@ License-File: LICENSE
|
|
|
692
692
|
Requires-Dist: openpyxl>=3.1.3
|
|
693
693
|
Requires-Dist: Bottleneck>=1.5.0
|
|
694
694
|
Requires-Dist: expydb>=0.6.0
|
|
695
|
-
Requires-Dist:
|
|
696
|
-
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.4.1
|
|
695
|
+
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
|
|
697
696
|
Provides-Extra: dev
|
|
698
697
|
Requires-Dist: pytest>=7.3; extra == "dev"
|
|
699
698
|
Requires-Dist: bumpver; extra == "dev"
|
|
@@ -21,7 +21,7 @@ from .survival import (
|
|
|
21
21
|
generate_survival_repeated_observations
|
|
22
22
|
)
|
|
23
23
|
|
|
24
|
-
from .generator import create_artificial_data, design_exposure_scenario
|
|
24
|
+
from .generator import create_artificial_data, design_exposure_scenario, ExposureDataDict
|
|
25
25
|
|
|
26
26
|
from .expydb import (
|
|
27
27
|
to_dataset,
|
|
@@ -11,7 +11,8 @@ def create_artificial_data(
|
|
|
11
11
|
t_max,
|
|
12
12
|
dt,
|
|
13
13
|
exposure_paths=["oral", "topical", "contact"],
|
|
14
|
-
intensity=[0.1, 0.5, 0.05]
|
|
14
|
+
intensity=[0.1, 0.5, 0.05],
|
|
15
|
+
seed=1,
|
|
15
16
|
):
|
|
16
17
|
rng = np.random.default_rng(1)
|
|
17
18
|
time = np.arange(0, t_max, step=dt) # daily time resolution
|
|
@@ -40,6 +40,12 @@ def prepare_survival_data_for_conditional_binomial(observations: xr.Dataset) ->
|
|
|
40
40
|
nsurv.isel(time=list(range(0, len(nsurv.time)-1))).values
|
|
41
41
|
]).astype(int))})
|
|
42
42
|
|
|
43
|
+
observations = observations.assign_coords({
|
|
44
|
+
"survivors_at_start": (("id", "time"), np.broadcast_to(
|
|
45
|
+
nsurv.isel(time=0).values.reshape(-1,1),
|
|
46
|
+
shape=nsurv.shape
|
|
47
|
+
).astype(int))})
|
|
48
|
+
|
|
43
49
|
return observations
|
|
44
50
|
|
|
45
51
|
|
|
@@ -1,10 +1,3 @@
|
|
|
1
|
-
"""
|
|
2
|
-
GUTS models
|
|
3
|
-
|
|
4
|
-
TODO: Import guts models from mempy and update to work well with jax
|
|
5
|
-
TODO: Based on this implement the bufferguts model in mempy and import in the
|
|
6
|
-
bufferguts case-study
|
|
7
|
-
"""
|
|
8
1
|
from functools import partial
|
|
9
2
|
from pathlib import Path
|
|
10
3
|
|
|
@@ -19,79 +12,37 @@ from pymob.solvers.symbolic import (
|
|
|
19
12
|
PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
|
|
20
13
|
)
|
|
21
14
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
)
|
|
15
|
+
class RED_SD:
    """Simplest GUTS model in this module, mainly used for testing.

    The state is ``(D, H)``: scaled damage and cumulative hazard.
    """

    @staticmethod
    def _rhs_jax(t, y, x_in, kd, b, m, hb):
        """Return the time derivatives ``(dD/dt, dH/dt)``.

        ``x_in.evaluate(t)`` supplies the external exposure at time ``t``.
        Damage relaxes towards the exposure with rate ``kd``; the hazard rate
        is ``b`` times the damage exceeding the threshold ``m``, plus the
        background hazard ``hb``.
        """
        damage, _cumulative_hazard = y
        exposure = x_in.evaluate(t)
        excess_damage = jnp.maximum(0.0, damage - m)
        return kd * (exposure - damage), b * excess_damage + hb

    @staticmethod
    def _solver_post_processing(results, t, interpolation):
        """Add ``survival = exp(-H)`` and the interpolated exposure to ``results``.

        ``results`` is modified in place and also returned.
        """
        cumulative_hazard = results["H"]
        results["survival"] = jnp.exp(-cumulative_hazard)
        evaluate_over_time = jax.vmap(interpolation.evaluate)
        results["exposure"] = evaluate_over_time(t)
        return results
|
|
33
29
|
|
|
34
30
|
red_sd = RED_SD._rhs_jax
|
|
35
31
|
red_sd_post_processing = RED_SD._solver_post_processing
|
|
36
32
|
|
|
37
|
-
|
|
38
|
-
|
|
33
|
+
class RED_SD_DA(RED_SD):
    """Variant of ``RED_SD`` whose hazard is driven by a weighted damage sum.

    Post-processing is inherited unchanged from ``RED_SD``.
    """

    @staticmethod
    def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
        """Return ``(dD/dt, dH/dt)``.

        Damage relaxes towards ``x_in.evaluate(t)`` with rate ``kd``. The
        hazard rate is ``b`` times the excess of ``sum(w * D)`` over the
        threshold ``m``, plus the background hazard ``hb``.
        """
        damage, _cumulative_hazard = y
        exposure = x_in.evaluate(t)
        combined_damage = jnp.sum(w * damage)
        excess_damage = jnp.maximum(0.0, combined_damage - m)
        return kd * (exposure - damage), b * excess_damage + hb
|
|
39
40
|
|
|
40
|
-
|
|
41
|
-
|
|
41
|
+
red_sd_da = RED_SD_DA._rhs_jax
|
|
42
|
+
red_sd_da_post_processing = RED_SD_DA._solver_post_processing
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
def
|
|
45
|
-
"""Computes the stochastic death survival probability after computing
|
|
46
|
-
|
|
47
|
-
"""
|
|
48
|
-
# calculate survival
|
|
49
|
-
results["exposure"] = jax.vmap(interpolation.evaluate)(t)
|
|
50
|
-
p_surv = survival_jax(t, results["D"], z, k_k, h_b)
|
|
51
|
-
results["survival"] = p_surv
|
|
52
|
-
results["lethality"] = 1 - p_surv
|
|
53
|
-
return results
|
|
54
|
-
|
|
55
|
-
def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
|
|
56
|
-
results["exposure"] = jax.vmap(interpolation.evaluate)(t)
|
|
57
|
-
p_surv = survival_IT_jax(t, results["D"], alpha, beta, h_b, eps)
|
|
58
|
-
results["survival"] = p_surv
|
|
59
|
-
results["H"] = - jnp.log(p_surv)
|
|
60
|
-
return results
|
|
61
|
-
|
|
62
|
-
def post_exposure(results, t, interpolation):
|
|
63
|
-
results["survival"] = jnp.exp(-results["H"])
|
|
64
|
-
results["exposure"] = jax.vmap(interpolation.evaluate)(t)
|
|
65
|
-
return results
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def no_post_processing(results):
|
|
69
|
-
return results
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@jax.jit
|
|
73
|
-
def survival_jax(t, damage, z, kk, h_b):
|
|
74
|
-
"""
|
|
75
|
-
survival probability derived from hazard
|
|
76
|
-
first calculate cumulative Hazard by integrating hazard cumulatively over t
|
|
77
|
-
then calculate the resulting survival probability
|
|
78
|
-
It was checked that `survival_jax` behaves exactly the same as `survival`
|
|
79
|
-
"""
|
|
80
|
-
hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
|
|
81
|
-
H = jnp.array([jax.scipy.integrate.trapezoid(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
|
|
82
|
-
# H = jnp.array([jnp.trapz(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
|
|
83
|
-
S = jnp.exp(-H)
|
|
84
|
-
|
|
85
|
-
return S
|
|
86
|
-
|
|
87
|
-
@jax.jit
|
|
88
|
-
def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
|
|
89
|
-
d_max = jnp.squeeze(jnp.array([jnp.max(damage[:i+1])+eps for i in range(len(t))]))
|
|
90
|
-
F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
|
|
91
|
-
S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
|
|
92
|
-
return S
|
|
93
|
-
|
|
94
|
-
def guts_jax(t, y, C_0, k_d, z, b, h_b):
|
|
45
|
+
def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
|
|
95
46
|
# for constant exposure
|
|
96
47
|
D, H, S = y
|
|
97
48
|
dD_dt = k_d * (C_0 - D)
|
|
@@ -103,16 +54,6 @@ def guts_jax(t, y, C_0, k_d, z, b, h_b):
|
|
|
103
54
|
|
|
104
55
|
return dD_dt, dH_dt, dS_dt
|
|
105
56
|
|
|
106
|
-
|
|
107
|
-
def RED_IT(t, y, x_in, kd):
|
|
108
|
-
D, = y
|
|
109
|
-
C = x_in.evaluate(t)
|
|
110
|
-
|
|
111
|
-
dD_dt = kd * (C - D)
|
|
112
|
-
|
|
113
|
-
return (dD_dt, )
|
|
114
|
-
|
|
115
|
-
|
|
116
57
|
def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
|
|
117
58
|
# for constant exposure
|
|
118
59
|
D, H, S = y
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
import numpy as np
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from numpyro.infer import Predictive
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def survival_predictions(
    probs, n_trials,
    eps: float = 0.0,
    seed=1,
    mode: Literal["survival", "lethality", "deaths"] = "survival"
):
    """Draw survivor counts from a multinomial time-of-death distribution.

    For every experiment, the survival-probability curve is converted into
    per-interval death probabilities. This includes one extra bin for the
    mortality remaining after the last time point. Deaths per interval are
    then drawn from a multinomial distribution.

    Parameters
    ----------
    probs : ArrayLike
        2D array of survival probabilities per time point for each
        experiment, dims=(experiment, time). Rows are presumably
        non-increasing and start at 1.
    n_trials : ArrayLike
        1D array with the number of organisms at the beginning of each
        experiment, dims=(experiment,).
    eps : float, optional
        Numerical tolerance, by default 0. When ``eps > 0``, survival
        probabilities are truncated to multiples of ``eps``. Death
        probabilities are always clipped to ``[eps, 1 - eps]`` before
        renormalisation.
    seed : int or numpy.random.Generator, optional
        Passed to ``numpy.random.default_rng``, by default 1. A ``Generator``
        is used as-is, so repeated calls continue its random stream.
    mode : str, optional
        How the random draws are returned:
        - survival: Decreasing from n_trials to 0
        - lethality: Increasing from 0 to n_trials
        - deaths: Between 0 and n_trials in each interval

    Returns
    -------
    Array
        Integer counts with dims=(experiment, time).

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of 'survival', 'lethality', or 'deaths'.
    """
    # validate before doing any (potentially expensive) sampling
    if mode not in ("survival", "lethality", "deaths"):
        raise NotImplementedError(
            f"Mode {mode} is not implemented. "
            "Use one of 'survival', 'lethality', or 'deaths'."
        )

    # With eps <= 0 the truncation below would divide by zero and turn every
    # probability into NaN, which makes ``rng.multinomial`` raise.
    truncate = bool(np.all(np.asarray(eps) > 0))

    def survival_to_death_probs(pr_survival):
        """Convert one survival curve into normalised multinomial death probs."""
        # float64 keeps the normalised pvals within numpy's sum tolerance
        pr_survival = np.asarray(pr_survival, dtype=float)
        if truncate:
            # truncate, because numeric errors below the solver tolerance can
            # lead to negative values in the difference
            pr_survival = np.trunc(pr_survival / eps) * eps
        pr_death = pr_survival[:-1] - pr_survival[1:]

        pr_death = np.concatenate([
            # a leading zero "simulates" no deaths at T = 0
            np.zeros((1,)),
            # Delta S for each observed interval
            pr_death,
            # the remaining mortality as T -> infinity
            np.ones((1,)) - pr_death.sum(),
        ])

        # keep probabilities away from exactly 0 or 1 (problematic for
        # probabilities) and make sure the vector sums to 1
        pr_death = np.clip(pr_death, eps, 1 - eps)
        return pr_death / pr_death.sum()

    rng = np.random.default_rng(seed)
    deaths = jnp.array([
        rng.multinomial(n=n, pvals=survival_to_death_probs(pr_survival))
        for n, pr_survival in zip(n_trials, probs)
    ])

    # drop the last bin: it holds the simulated, unobserved mortality
    deaths = deaths[:, :-1]

    if mode == "deaths":
        return deaths
    if mode == "lethality":
        return deaths.cumsum(axis=1)
    return np.expand_dims(n_trials, axis=1) - deaths.cumsum(axis=1)
|
|
80
|
+
|
|
81
|
+
def posterior_predictions(sim, idata, seed=None):
    """Make posterior predictions for survival data.

    Re-evaluates the simulation's inference model for every posterior draw in
    ``idata`` and replaces ``posterior_model_fits`` and
    ``posterior_predictive`` on ``idata`` with the recomputed groups.

    Parameters
    ----------
    sim : simulation object
        Provides ``config``, ``inferer`` and ``coordinates``.
    idata : InferenceData
        Must contain ``posterior``, ``unconstrained_posterior`` and
        ``posterior_model_fits`` groups.
    seed : int, optional
        Random seed; defaults to ``sim.config.simulation.seed``.

    Returns
    -------
    InferenceData
        The input ``idata``, modified in place.

    Raises
    ------
    AssertionError
        If the re-evaluated posterior or model fits deviate from the stored
        ones by more than a tolerance derived from ``jaxsolver.atol``.
    """
    if seed is None:
        seed = sim.config.simulation.seed

    # mapping access via ``Dataset.dims[...]`` is deprecated in xarray;
    # ``sizes`` is the supported dimension-name -> length mapping
    n = idata.posterior.sizes["draw"]
    c = idata.posterior.sizes["chain"]

    key = jax.random.PRNGKey(seed)

    obs, masks = sim.inferer.observation_parser()

    model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
    if sim.config.inference_numpyro.user_defined_error_model is None:
        # NOTE(review): presumably lets the model sample ``survival_obs``
        # instead of conditioning on the data — confirm against the model
        model_kwargs["obs"]["survival"] = None

    # prepare model
    model = partial(
        sim.inferer.inference_model,
        solver=sim.inferer.evaluator,
        **model_kwargs,
    )

    posterior_samples = {
        k: np.array(v["data"])
        for k, v in idata.unconstrained_posterior.to_dict()["data_vars"].items()
    }

    predictive = Predictive(
        model, posterior_samples=posterior_samples,
        num_samples=n, batch_ndims=2,
    )

    samples = predictive(key)

    if sim.config.inference_numpyro.user_defined_error_model is not None:
        # A single Generator shared by all draws: ``np.random.default_rng``
        # returns a Generator unaltered, so every draw continues the stream.
        # Passing the same integer seed per draw would reuse identical random
        # numbers and make the multinomial noise correlated across draws.
        rng = np.random.default_rng(seed)
        predict = partial(
            survival_predictions,
            n_trials=obs["survival"][:, 0].astype(int),
            eps=obs["eps"],
            seed=rng,
            mode="survival",
        )
        # TODO: switch to vmap and jit; a previous attempt did not work, so
        # test any such change thoroughly.
        chains = [
            [predict(draw) for draw in samples["survival"][i]]
            for i in range(c)
        ]
        posterior_predictive = {"survival_obs": np.array(chains)}
    else:
        posterior_predictive = {"survival_obs": samples.pop("survival_obs")}

    new_idata = sim.inferer.to_arviz_idata(
        posterior=samples,
        posterior_predictive=posterior_predictive,
        n_draws=n,
        n_chains=c,
    )

    # update chain names in case they were subselected (clustering)
    new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})

    # assert the new posterior matches the old posterior
    tol = sim.config.jaxsolver.atol * 100
    abs_diff_posterior = np.abs(idata.posterior - new_idata.posterior)
    np.testing.assert_array_less(abs_diff_posterior.mean().to_array(), tol)

    # model fits accumulate solver error over time, so scale the tolerance
    fit_tol = tol * sim.coordinates["time"].max()
    abs_diff_fits = np.abs(new_idata.posterior_model_fits - idata.posterior_model_fits)
    np.testing.assert_array_less(abs_diff_fits.mean().to_array(), fit_tol)

    idata.posterior_model_fits = new_idata.posterior_model_fits
    idata.posterior_predictive = new_idata.posterior_predictive

    return idata
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from . import base
|
|
2
2
|
from . import ecx
|
|
3
3
|
from . import report
|
|
4
|
+
from . import utils
|
|
4
5
|
|
|
5
6
|
from .base import (
|
|
6
7
|
GutsBase,
|
|
@@ -11,4 +12,12 @@ from .base import (
|
|
|
11
12
|
from .ecx import ECxEstimator, LPxEstimator
|
|
12
13
|
from .report import GutsReport
|
|
13
14
|
|
|
14
|
-
from .mempy import PymobSimulator
|
|
15
|
+
from .mempy import PymobSimulator
|
|
16
|
+
from .utils import (
|
|
17
|
+
GutsBaseError
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from .constructors import (
|
|
21
|
+
construct_sim_from_config,
|
|
22
|
+
load_idata
|
|
23
|
+
)
|