guts-base 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +14 -0
- guts_base/data/__init__.py +34 -0
- guts_base/data/expydb.py +247 -0
- guts_base/data/generator.py +96 -0
- guts_base/data/openguts.py +294 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +137 -0
- guts_base/data/time_of_death.py +571 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +251 -0
- guts_base/plot.py +162 -0
- guts_base/prob.py +412 -0
- guts_base/sim/__init__.py +14 -0
- guts_base/sim/base.py +464 -0
- guts_base/sim/ecx.py +357 -0
- guts_base/sim/mempy.py +252 -0
- guts_base/sim/report.py +72 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.2.dist-info/METADATA +836 -0
- guts_base-0.8.2.dist-info/RECORD +24 -0
- guts_base-0.8.2.dist-info/WHEEL +5 -0
- guts_base-0.8.2.dist-info/entry_points.txt +3 -0
- guts_base-0.8.2.dist-info/licenses/LICENSE +674 -0
- guts_base-0.8.2.dist-info/top_level.txt +1 -0
guts_base/mod.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GUTS models
|
|
3
|
+
|
|
4
|
+
TODO: Import guts models from mempy and update to work well with jax
|
|
5
|
+
TODO: Based on this implement the bufferguts model in mempy and import in the
|
|
6
|
+
bufferguts case-study
|
|
7
|
+
"""
|
|
8
|
+
from functools import partial
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import jax
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
import sympy as sp
|
|
14
|
+
import numpy as np
|
|
15
|
+
from scipy.interpolate import interp1d
|
|
16
|
+
|
|
17
|
+
from pymob.solvers.base import mappar
|
|
18
|
+
from pymob.solvers.symbolic import (
|
|
19
|
+
PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from mempy.model import (
|
|
23
|
+
RED_IT,
|
|
24
|
+
RED_SD,
|
|
25
|
+
RED_IT_DA,
|
|
26
|
+
RED_SD_DA,
|
|
27
|
+
RED_IT_IA,
|
|
28
|
+
RED_SD_IA,
|
|
29
|
+
BufferGUTS_IT,
|
|
30
|
+
BufferGUTS_IT_CA,
|
|
31
|
+
BufferGUTS_IT_DA
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# JAX right-hand-side functions taken from the mempy reduced GUTS models.
# NOTE: the name `RED_IT` is re-bound by a `def RED_IT` further down in this
# module, so `red_it` must be captured here while `RED_IT` still refers to
# the object imported from `mempy.model`.
red_sd, red_it = RED_SD._rhs_jax, RED_IT._rhs_jax
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def p_survival(results, t, interpolation, z, k_k, h_b):
    """Add exposure and stochastic-death (SD) survival to ``results``.

    The exposure interpolation is evaluated at every time point of ``t``
    and the survival probability is computed from the damage trajectory
    ``results["D"]`` with :func:`survival_jax`, i.e. from the hazard
    ``k_k * max(D - z, 0) + h_b``.

    Parameters
    ----------
    results : dict
        Simulation results; must contain the damage trajectory ``"D"``.
        The dict is modified in place.
    t : array
        Time points at which ``results`` were computed.
    interpolation
        Object whose ``evaluate`` method is vmapped over ``t``
        (presumably a JAX-traceable interpolation — confirm with caller).
    z : float
        Damage threshold above which the hazard increases.
    k_k : float
        Slope of the hazard with respect to damage above ``z``.
    h_b : float
        Constant (background) hazard rate.

    Returns
    -------
    dict
        The same ``results`` dict with ``"exposure"``, ``"survival"`` and
        ``"lethality"`` (``1 - survival``) added.
    """
    # calculate survival
    results["exposure"] = jax.vmap(interpolation.evaluate)(t)
    p_surv = survival_jax(t, results["D"], z, k_k, h_b)
    results["survival"] = p_surv
    results["lethality"] = 1 - p_surv
    return results
|
|
48
|
+
|
|
49
|
+
def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
    """Add exposure, IT survival and cumulative hazard to ``results`` in place.

    Survival is computed from ``results["D"]`` with :func:`survival_IT_jax`;
    the cumulative hazard is recovered as ``H = -log(survival)``.
    Returns the (mutated) ``results`` dict.
    """
    evaluate_at = jax.vmap(interpolation.evaluate)
    results["exposure"] = evaluate_at(t)

    survival = survival_IT_jax(t, results["D"], alpha, beta, h_b, eps)
    results["survival"] = survival
    # invert S = exp(-H)
    results["H"] = - jnp.log(survival)
    return results
|
|
55
|
+
|
|
56
|
+
def post_exposure(results, t, interpolation):
    """Derive survival from the cumulative hazard and add the exposure.

    ``results["survival"]`` is set to ``exp(-results["H"])`` and
    ``results["exposure"]`` to ``interpolation.evaluate`` vmapped over
    ``t``. ``results`` is modified in place and returned.
    """
    results.update(survival=jnp.exp(-results["H"]))
    results.update(exposure=jax.vmap(interpolation.evaluate)(t))
    return results
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def no_post_processing(results):
    """Identity post-processing hook: hand ``results`` back untouched."""
    unchanged = results
    return unchanged
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@jax.jit
def survival_jax(t, damage, z, kk, h_b):
    """Stochastic-death survival probability derived from the hazard.

    The hazard ``h = kk * max(damage - z, 0) + h_b`` is integrated
    cumulatively over ``t`` with the trapezoidal rule, giving the
    cumulative hazard ``H`` with ``H[0] = 0``. The survival probability
    is ``S = exp(-H)``.

    The cumulative integral is computed in one vectorized pass (``cumsum``
    of the per-interval trapezoid areas). Integrating every prefix
    ``t[:i+1]`` separately would be O(n**2) and unroll ``len(t)`` ops into
    the jit trace. Results agree with prefix-wise trapezoid integration up
    to floating-point rounding.

    Parameters
    ----------
    t : array
        Time points; assumed 1-D and aligned with axis 0 of ``damage``
        (an array of the same shape as the hazard also works).
    damage : array
        Damage trajectory with time along axis 0.
    z : float
        Damage threshold above which the hazard increases.
    kk : float
        Slope of the hazard with respect to damage above ``z``.
    h_b : float
        Constant (background) hazard rate.

    Returns
    -------
    array
        Survival probability at every time point (same shape as the hazard).
    """
    hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b

    # time steps along axis 0, reshaped so they broadcast over any trailing
    # axes of the hazard (no-op when t already has the hazard's rank)
    dt = jnp.diff(t, axis=0)
    dt = jnp.reshape(dt, dt.shape + (1,) * (hazard.ndim - dt.ndim))

    # trapezoid area of each interval [t_{i-1}, t_i]
    increments = 0.5 * (hazard[1:] + hazard[:-1]) * dt
    H = jnp.concatenate(
        [jnp.zeros_like(hazard[:1]), jnp.cumsum(increments, axis=0)], axis=0
    )
    S = jnp.exp(-H)

    return S
|
|
80
|
+
|
|
81
|
+
@jax.jit
def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
    """Individual-tolerance (IT) survival probability.

    With the running maximum of damage ``d_max(t) = max_{s <= t} D(s) + eps``:

        F(t) = 1 / (1 + (d_max / alpha) ** -beta)   (0 where d_max <= 0)
        S(t) = (1 - F(t)) * exp(-h_b * t)

    The running maximum is computed with ``jax.lax.cummax`` in one pass.
    A per-prefix ``max`` would be O(n**2) and unroll ``len(t)`` ops under jit.
    As in the prefix formulation, the maximum is taken over all non-time
    axes of ``damage`` as well.

    Parameters
    ----------
    t : array
        Time points; assumed 1-D, aligned with axis 0 of ``damage``.
    damage : array
        Damage trajectory with time along axis 0.
    alpha, beta : float
        Scale and shape of the log-logistic threshold distribution ``F``.
    h_b : float
        Constant (background) hazard rate.
    eps : float
        Offset added to the running maximum of damage.

    Returns
    -------
    array
        Survival probability at every time point.
    """
    n_t = len(t)
    damage = damage[:n_t]
    # collapse trailing axes so the running max covers all of them per time step
    if damage.ndim > 1:
        damage = jnp.max(damage, axis=tuple(range(1, damage.ndim)))
    d_max = jnp.squeeze(jax.lax.cummax(damage, axis=0) + eps)

    F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
    S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
    return S
|
|
87
|
+
|
|
88
|
+
def guts_jax(t, y, C_0, k_d, z, b, h_b):
    """ODE right-hand side of a GUTS SD system for a constant exposure ``C_0``.

    The state ``y`` unpacks into three components ``(D, H, S)``. Returns the
    derivatives ``(dD/dt, dH/dt, dS/dt)`` where the hazard is switched on
    above the threshold ``z`` by a very steep arctan step. ``t`` is not
    used.
    """
    damage, _, surv = y
    dD = k_d * (C_0 - damage)

    excess = damage - z
    # smooth step: ~0 for damage below z, ~1 above it
    step = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * excess)
    dH = (b * step * excess + h_b)

    dS = -dH * surv

    return dD, dH, dS
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def RED_IT(t, y, x_in, kd):
    """Damage ODE right-hand side ``dD/dt = kd * (C(t) - D)``.

    ``C(t)`` comes from ``x_in.evaluate(t)`` and ``y`` must contain exactly
    one state, the damage ``D``. Returns a 1-tuple with the derivative.

    NOTE(review): this definition shadows the ``RED_IT`` imported from
    ``mempy.model`` at the top of the module.
    """
    (damage,) = y
    exposure = x_in.evaluate(t)
    return (kd * (exposure - damage), )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
    """ODE right-hand side of a GUTS SD system with time-variable exposure.

    Same dynamics as ``guts_jax``, but the exposure concentration is looked
    up at each call with ``x_in.evaluate(t)`` instead of being a constant.

    Parameters
    ----------
    t : float
        Current time, passed to ``x_in.evaluate``.
    y : sequence of three
        State ``(D, H, S)``; ``H`` is not used in the derivatives.
    x_in
        Object with an ``evaluate(t)`` method returning the exposure.
    k_d : float
        Rate constant of damage towards the exposure concentration.
    z : float
        Damage threshold above which the hazard increases.
    b : float
        Slope of the hazard with respect to damage above ``z``.
    h_b : float
        Constant (background) hazard rate.

    Returns
    -------
    tuple
        ``(dD/dt, dH/dt, dS/dt)``.
    """
    # exposure varies over time: interpolate it at the current t
    D, H, S = y
    C = x_in.evaluate(t)
    dD_dt = k_d * (C - D)

    # smooth step: ~0 for damage below z, ~1 above it
    switchDS = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (D - z))
    dH_dt = (b * switchDS * (D - z) + h_b)

    dS_dt = -dH_dt * S

    return dD_dt, dH_dt, dS_dt
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Interpolation:
    """Wrapper around :func:`scipy.interpolate.interp1d` along axis 0.

    Parameters
    ----------
    xs : array-like
        Sample locations.
    ys : array-like
        Sample values; axis 0 is aligned with ``xs``.
    method : str
        ``kind`` argument of ``interp1d`` (default ``"previous"``, a step
        function that holds the previous sample value).
    """

    def __init__(self, xs, ys, method="previous") -> None:
        interp_kwargs = dict(x=xs, y=ys, axis=0, kind=method)
        self.f = interp1d(**interp_kwargs)

    def evaluate(self, t) -> np.ndarray:
        """Return the interpolated value(s) at ``t``."""
        interpolant = self.f
        return interpolant(t)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class PiecewiseSymbolicSolver(PiecewiseSymbolicODESolver):
    """Symbolic piecewise solver for GUTS models with a damage threshold ``z``.

    The compiler recipe builds three functions:

    * ``"F"``: the compiled symbolic solution of the model;
    * ``"t_jump"``: the time at which the damage solution ``D`` equals ``z``;
    * ``"F_piecewise"``: the combination of both.

    ``solve`` advances ``F_piecewise`` from one time point to the next,
    holding the exposure at its value at the start of each interval.
    """

    # `kind` passed to `Interpolation` (scipy interp1d) for the exposure
    interpolation_type = "previous"

    def t_jump(self, func_name, compiled_functions=None):
        """Solve ``D(t) = z`` for ``t`` and generate code for the solution.

        Parameters
        ----------
        func_name : str
            Name of the generated function. It is also the stem of the
            ``.tex`` file written to ``self.output_path``.
        compiled_functions : dict, optional
            Previously compiled functions; must provide
            ``["F"]["algebraic_solutions"]["D"]``, an object with a ``.rhs``
            (presumably a sympy ``Eq`` for the damage solution).

        Returns
        -------
        tuple
            The simplified sympy expression for the jump time and the
            corresponding ``FunctionPythonCode``.

        Raises
        ------
        KeyError
            If ``compiled_functions`` lacks the damage solution.
        ValueError
            If ``D(t) = z`` does not have exactly one solution for ``t``.
        """
        # `None` instead of a shared mutable `{}` default
        if compiled_functions is None:
            compiled_functions = {}

        t, _, Y_0, theta = self.define_symbols()

        D_t = compiled_functions["F"]["algebraic_solutions"]["D"]
        z = theta["z"]
        eq = sp.Eq(D_t.rhs, z).expand()

        t_0 = sp.solve(eq, t)

        # explicit check: an `assert` would be stripped under `python -O`
        if len(t_0) != 1:
            raise ValueError(
                "Expected exactly one solution of D(t) = z for t, "
                f"found {len(t_0)}: {t_0}"
            )
        func = sp.simplify(t_0[0])

        python_code = FunctionPythonCode(
            func_name=func_name,
            lhs_0=("Y_0", tuple(Y_0.keys())),
            theta=("theta", tuple(theta.keys())),
            lhs=("Y",),
            rhs=(func,),
            expand_arguments=False,
            modules=("numpy","scipy"),
            docstring=""
        )

        tex = self.to_latex(solutions=[sp.Eq(sp.Symbol(func_name), func)])
        code_file = Path(self.output_path, f"{func_name}.tex")
        with open(code_file, "w", encoding="utf-8") as f:
            f.writelines(tex)

        return func, python_code

    def define_symbols(self):
        """Define the necessary symbols solely based on the function.

        Parameter names are taken from ``self.model``'s signature (excluding
        ``t``, ``dt``, ``y``, ``x_in``, ``Y``, ``X``). State names are derived
        from the model's return arguments via ``dX_dt2X``.

        Returns
        -------
        tuple
            ``(t, Y, Y_0, theta)``: the time symbol, a dict of state
            functions, a dict of initial-value symbols (``"<state>_0"``) and
            a dict of parameter symbols. All are declared positive and real.
        """
        thetanames = mappar(
            self.model, {},
            exclude=["t", "dt", "y", "x_in", "Y", "X"],
            to="names"
        )
        ynames = [dX_dt2X(a) for a in get_return_arguments(self.model)]

        # define symbols for t, Y, Y_0 and theta
        t = sp.Symbol("t", positive=True, real=True)
        Y = {y: sp.Function(y, positive=True, real=True) for y in ynames}
        Y_0 = {
            f"{y}_0": sp.Symbol(f"{y}_0", positive=True, real=True)
            for y in ynames
        }
        theta = {p: sp.Symbol(p, positive=True, real=True) for p in thetanames}

        return t, Y, Y_0, theta

    @property
    def compiler_recipe(self):
        """Mapping of compiled-function names to the callables that build them."""
        return {
            "F": self.compile_model,
            "t_jump": self.t_jump,
            "F_piecewise": partial(self.jump_solution, funcnames="F t_jump")
        }

    def solve(self, parameters, y0, x_in):
        """Step the piecewise solution over ``self.x``.

        Parameters
        ----------
        parameters : dict
            Model parameters; mapped onto the model signature with ``mappar``.
        y0 : dict
            Initial values; its keys name the entries of the returned dict.
        x_in : dict
            Input variables; ``x_in["exposure"]`` is used when ``"exposure"``
            is among the ``x_in`` input coordinates. Otherwise a constant
            exposure ``C_0`` from ``parameters`` is used.

        Returns
        -------
        dict
            One trajectory over ``self.x`` per key of ``y0``, plus
            ``"exposure"``.
        """
        odeargs = mappar(
            self.model,
            parameters,
            exclude=["t", "dt", "y", "x_in"],
            to="dict"
        )

        F_piecewise = self.compiled_functions["F_piecewise"]["compiled_function"]

        # get arguments
        time = np.array(self.x)
        Y_0_values = list(y0.values())

        # handle interpolation
        if "exposure" in self.coordinates_input_vars["x_in"]:
            exposure_interpolation = Interpolation(
                xs=self.coordinates_input_vars["x_in"]["exposure"][self.x_dim],
                ys=x_in["exposure"],
                method=self.interpolation_type
            )
        else:
            # constant exposure; force a float dtype, since `full_like` on an
            # integer time axis would silently truncate C_0
            exposure_interpolation = Interpolation(
                xs=self.x,
                ys=np.full_like(time, odeargs["C_0"], dtype=float),  # type: ignore
                method=self.interpolation_type
            )

        # run main loop: advance from each time point to the next, with the
        # exposure held at its value at the start of the interval
        sol = [Y_0_values]
        for t_prev, t_next in zip(time[:-1], time[1:]):
            odeargs["C_0"] = exposure_interpolation.evaluate(t_prev)  # type: ignore

            y_t = F_piecewise(
                t=t_next - t_prev,
                Y_0=sol[-1],
                θ=tuple(odeargs.values()),  # type: ignore
                ε=1e-14
            )
            sol.append(y_t)

        Y_t = np.array(sol)

        results = {k: y for k, y in zip(y0.keys(), Y_t.T)}
        results["exposure"] = exposure_interpolation.evaluate(time)

        return results
|
guts_base/plot.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import arviz as az
|
|
3
|
+
from matplotlib import pyplot as plt
|
|
4
|
+
|
|
5
|
+
from pymob.sim.plot import SimulationPlot
|
|
6
|
+
from pymob import SimulationBase
|
|
7
|
+
|
|
8
|
+
def plot_prior_predictions(
    sim: SimulationBase,
    data_vars=["survival"],
    title_func=lambda sp, c: f"{sp.observations.id.values[0]}"
):
    """Plot prior predictions for every ``id`` side by side and save the figure.

    The grid has one column per ``id`` and one row per entry of
    ``data_vars``. Y-labels are kept only in the first column and x-labels
    only in the bottom row. The figure is written to
    ``<sim.output_path>/combined_prior_predictions.png``.

    Parameters
    ----------
    sim : SimulationBase
        Simulation whose ``inferer`` provides ``prior_predictions()``.
    data_vars : list of str
        Data variables to plot, one row each.
    title_func : callable
        Passed to ``SimulationPlot.set_titles``; by default titles each
        panel with its ``id``.
    """
    idata = sim.inferer.prior_predictions()  # type: ignore

    def plot_survival_data_probs(idata):
        # scale predicted probabilities by the observed `survival` at time 0 of
        # the panel being drawn; `id_` is the loop variable below, looked up
        # when SimulationPlot calls this (presumably inside the same iteration)
        return idata["survival"] * sim.observations.sel(id=[id_], time=0).survival

    # `sizes` maps dimension names to lengths (`Dataset.dims[...]` is deprecated)
    fig, axes = plt.subplots(
        ncols=sim.observations.sizes["id"], nrows=len(data_vars),
        sharex=True, sharey="row", figsize=(30, 10), squeeze=False)
    for i, id_ in enumerate(sim.coordinates["id"]):
        simplot = SimulationPlot(
            observations=sim.observations.sel(id=[id_]),
            idata=idata.sel(id=[id_]),  # type: ignore
            rows=data_vars,
            coordinates=sim.coordinates,
            config=sim.config,
            obs_idata_map={
                "survival": plot_survival_data_probs
            },
            idata_groups=["prior_model_fits"],  # type: ignore
        )
        # replace simplot axis
        for j, k in enumerate(simplot.rows):
            simplot.axes_map[k]["all"] = axes[j][i]

        simplot.plot_data_variables()
        simplot.set_titles(title_func)

        # keep y-labels in the first column and x-labels in the last row only
        last_row = len(simplot.rows) - 1
        for j, k in enumerate(simplot.rows):
            ax = simplot.axes_map[k]["all"]
            if i != 0:
                ax.set_ylabel("")
            if j != last_row:
                ax.set_xlabel("")

        simplot.close()

    fig.tight_layout()
    fig.savefig(f"{sim.output_path}/combined_prior_predictions.png")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def plot_posterior_predictions(
    sim: SimulationBase,
    data_vars=["survival"],
    title_func=lambda sp, c: f"{sp.observations.id.values[0]}",
    groups=["posterior_model_fits", "posterior_predictive"],
):
    """Plot posterior predictions for every ``id`` side by side and save the figure.

    The grid has one column per ``id`` and one row per entry of
    ``data_vars``. Y-labels are kept only in the first column and x-labels
    only in the bottom row. The figure is written to
    ``<sim.output_path>/combined_posterior_predictions.png``.

    Parameters
    ----------
    sim : SimulationBase
        Simulation whose ``inferer.idata`` holds the posterior groups.
    data_vars : list of str
        Data variables to plot, one row each.
    title_func : callable
        Passed to ``SimulationPlot.set_titles``; by default titles each
        panel with its ``id``.
    groups : list of str
        InferenceData groups to draw.
    """
    # `sizes` maps dimension names to lengths (`Dataset.dims[...]` is deprecated)
    fig, axes = plt.subplots(
        ncols=sim.observations.sizes["id"], nrows=len(data_vars),
        sharex=True, sharey="row", figsize=(30, 10), squeeze=False
    )

    def plot_survival_data_probs_and_preds(dataset):
        # model fits are probabilities: scale them by the observed `survival`
        # at time 0 of the current `id_` (loop variable below, looked up at
        # call time); posterior predictive draws are returned as they are.
        # Any other group falls through and returns None.
        if dataset.attrs["group"] == "posterior_model_fits":
            return dataset["survival"] * sim.observations.sel(id=[id_], time=0).survival
        if dataset.attrs["group"] == "posterior_predictive":
            return dataset["survival"]
        return None

    for i, id_ in enumerate(sim.coordinates["id"]):
        simplot = SimulationPlot(
            observations=sim.observations.sel(id=[id_]),
            idata=sim.inferer.idata.sel(id=[id_]),  # type: ignore
            rows=data_vars,
            coordinates=sim.coordinates,
            config=sim.config,
            obs_idata_map={
                "survival": plot_survival_data_probs_and_preds,
            },
            idata_groups=groups,  # type: ignore
        )

        # replace simplot axis
        for j, k in enumerate(simplot.rows):
            simplot.axes_map[k]["all"] = axes[j][i]

        simplot.plot_data_variables()
        simplot.set_titles(title_func)

        # keep y-labels in the first column and x-labels in the last row only
        last_row = len(simplot.rows) - 1
        for j, k in enumerate(simplot.rows):
            ax = simplot.axes_map[k]["all"]
            if i != 0:
                ax.set_ylabel("")
            if j != last_row:
                ax.set_xlabel("")

        simplot.close()

    fig.tight_layout()
    fig.savefig(f"{sim.output_path}/combined_posterior_predictions.png")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def plot_survival(sim, results):
    """Draw observed and simulated survival on a single new axis.

    Observations are ``sim.observations.survival`` divided by
    ``sim.observations.subject_count`` (black markers); the simulated
    ``results.survival`` is drawn as black lines. The figure is neither
    saved nor returned.
    """
    _, axis = plt.subplots(1, 1)
    observed = sim.observations
    observed_fraction = observed.survival / observed.subject_count
    axis.plot(observed.time, observed_fraction.T, marker="o", color="black")
    axis.plot(results.time, results.survival.T, color="black")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def plot_survival_multipanel(sim, results, ncols=6, title=lambda _id: _id):
    """Plot posterior survival (mean and 95 % HDI) with observations, one panel per ``id``.

    Observed survival is normalised by its value at the first time point.
    The figure is saved to ``<sim.output_path>/survival_multipanel.png``.

    Parameters
    ----------
    sim
        Simulation providing ``observations`` and ``output_path``.
    results
        Posterior results with ``chain``, ``draw`` and ``id`` dimensions and
        a ``survival`` variable.
    ncols : int
        Number of panel columns.
    title : callable
        Maps an ``id`` value to the panel title.

    Returns
    -------
    str
        Path of the saved figure.
    """
    n_panels = results.sizes["id"]

    nrows = int(np.ceil(n_panels / ncols))

    # squeeze=False: always a 2-D array of axes, even for a 1x1 grid (a bare
    # Axes object has no `.flatten()`)
    fig, axes = plt.subplots(
        nrows, ncols, sharex=True, sharey=True,
        figsize=(ncols * 2 + 2, nrows * 1.5 + 2), squeeze=False,
    )
    axes = axes.flatten()
    mean = results.mean(("chain", "draw"))
    hdi = az.hdi(results, hdi_prob=0.95)
    survival = sim.observations.survival / sim.observations.survival.isel(time=0)

    plot_kwargs = {"color": "black"}
    for _id, ax in zip(sim.observations.id.values, axes):
        ax.set_ylim(-0.05, 1.05)

        # TODO: use time unit from observations (?)
        ax.set_xlabel("Time")
        ax.set_ylabel("Survival")
        ax.plot(mean.time, mean.sel(id=_id).survival.T, **plot_kwargs)
        ax.fill_between(hdi.time, *hdi.sel(id=_id).survival.T, alpha=.5, **plot_kwargs)
        ax.plot(survival.time, survival.sel(id=_id).T, ls="", marker="o", alpha=.5, **plot_kwargs)
        ax.set_title(title(_id))

    out = f"{sim.output_path}/survival_multipanel.png"
    fig.tight_layout()
    fig.savefig(out)

    return out
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def multipanel_title(sim, _id):
    """Build a two-line panel title for one ``id``.

    Line one is the ``exposure_path`` value. Line two is the maximum of
    ``concentration_closer_x_rac`` times 100, rounded to 3 decimals and
    followed by `` %RAC``.
    """
    obs_for_id = sim.observations.sel(id=_id)
    path_label = str(obs_for_id.exposure_path.values)
    max_rac_percent = np.round(obs_for_id.concentration_closer_x_rac.max().values * 100, 3)
    return f"{path_label}\n{max_rac_percent!s} %RAC"
|
|
142
|
+
|
|
143
|
+
def plot_intermediate_results(sim: SimulationBase, id=0):
    """Run the simulation once and plot every data variable for one ``id``.

    One subplot row is drawn per data variable in
    ``sim.config.data_structure.data_variables``.

    Parameters
    ----------
    sim : SimulationBase
        Simulation to dispatch and evaluate.
    id : int
        Positional index (``isel``) along the ``id`` dimension to plot.
    """
    e = sim.dispatch()
    e()

    results = e.results.isel(id=[id])

    datavars = sim.config.data_structure.data_variables
    # squeeze=False: a single data variable would otherwise give a bare Axes,
    # which is not iterable
    fig, axes = plt.subplots(len(datavars), 1, squeeze=False)

    for ax, dv in zip(axes[:, 0], datavars):
        res = results[dv]
        if res.ndim > 2:
            res_ = res.transpose(..., "time", sim.config.simulation.batch_dimension)
            for r in res_:
                ax.plot(r.time, r, alpha=.5)
        else:
            ax.plot(results.time, res.transpose("time", sim.config.simulation.batch_dimension), alpha=.5)

        ax.set_ylabel(dv)