guts-base 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

guts_base/mod.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ GUTS models
3
+
4
+ TODO: Import guts models from mempy and update to work well with jax
5
+ TODO: Based on this implement the bufferguts model in mempy and import in the
6
+ bufferguts case-study
7
+ """
8
+ from functools import partial
9
+ from pathlib import Path
10
+
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import sympy as sp
14
+ import numpy as np
15
+ from scipy.interpolate import interp1d
16
+
17
+ from pymob.solvers.base import mappar
18
+ from pymob.solvers.symbolic import (
19
+ PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
20
+ )
21
+
22
+ from mempy.model import (
23
+ RED_IT,
24
+ RED_SD,
25
+ RED_IT_DA,
26
+ RED_SD_DA,
27
+ RED_IT_IA,
28
+ RED_SD_IA,
29
+ BufferGUTS_IT,
30
+ BufferGUTS_IT_CA,
31
+ BufferGUTS_IT_DA
32
+ )
33
+
34
# Module-level aliases for the `_rhs_jax` attribute of the imported mempy
# model classes (presumably their JAX right-hand-side functions -- TODO
# confirm against mempy).
# NOTE(review): `red_it` is bound here from the *imported* `RED_IT`; a function
# named `RED_IT` defined further down in this module rebinds that name.
red_sd = RED_SD._rhs_jax
red_it = RED_IT._rhs_jax
36
+
37
+
38
def p_survival(results, t, interpolation, z, k_k, h_b):
    """Add exposure, survival and lethality entries to ``results``.

    The exposure is obtained by mapping ``interpolation.evaluate`` over ``t``
    with ``jax.vmap``.  The survival probability is computed by
    ``survival_jax`` from the damage stored under ``results["D"]`` and the
    parameters ``z``, ``k_k`` and ``h_b``; lethality is its complement.

    Parameters
    ----------
    results : mapping that already contains the key ``"D"``; it is modified
        in place.
    t : time points passed to the interpolation and to ``survival_jax``.
    interpolation : object exposing an ``evaluate`` method.
    z, k_k, h_b : forwarded to ``survival_jax`` (threshold, rate multiplier
        and additive hazard term there).

    Returns
    -------
    The same ``results`` object with the keys ``"exposure"``, ``"survival"``
    and ``"lethality"`` set.
    """
    evaluate_over_time = jax.vmap(interpolation.evaluate)
    results["exposure"] = evaluate_over_time(t)

    survival_probability = survival_jax(t, results["D"], z, k_k, h_b)
    results["survival"] = survival_probability
    results["lethality"] = 1 - survival_probability
    return results
48
+
49
def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
    """Add exposure, survival and cumulative hazard entries to ``results``.

    The exposure is obtained by mapping ``interpolation.evaluate`` over ``t``
    with ``jax.vmap``.  The survival probability comes from
    ``survival_IT_jax`` applied to ``results["D"]``; ``"H"`` is stored as the
    negative logarithm of that survival probability.

    Parameters
    ----------
    results : mapping that already contains the key ``"D"``; it is modified
        in place.
    t : time points passed to the interpolation and to ``survival_IT_jax``.
    interpolation : object exposing an ``evaluate`` method.
    alpha, beta, h_b, eps : forwarded unchanged to ``survival_IT_jax``.

    Returns
    -------
    The same ``results`` object with the keys ``"exposure"``, ``"survival"``
    and ``"H"`` set.
    """
    evaluate_over_time = jax.vmap(interpolation.evaluate)
    results["exposure"] = evaluate_over_time(t)

    survival_probability = survival_IT_jax(
        t, results["D"], alpha, beta, h_b, eps
    )
    results["survival"] = survival_probability
    results["H"] = - jnp.log(survival_probability)
    return results
55
+
56
def post_exposure(results, t, interpolation):
    """Derive survival from the cumulative hazard and attach the exposure.

    ``results["survival"]`` is set to ``exp(-results["H"])`` and
    ``results["exposure"]`` to ``interpolation.evaluate`` mapped over ``t``
    with ``jax.vmap``.  ``results`` is modified in place and returned.
    """
    cumulative_hazard = results["H"]
    results["survival"] = jnp.exp(-cumulative_hazard)

    evaluate_over_time = jax.vmap(interpolation.evaluate)
    results["exposure"] = evaluate_over_time(t)
    return results
60
+
61
+
62
def no_post_processing(results):
    """Identity post-processing step: hand ``results`` back untouched."""
    return results
64
+
65
+
66
@jax.jit
def survival_jax(t, damage, z, kk, h_b):
    """Survival probability derived from a thresholded hazard rate.

    The hazard rate is ``kk * max(damage - z, 0) + h_b``.  It is integrated
    cumulatively over ``t`` (along axis 0) with the trapezoidal rule to give
    the cumulative hazard ``H``; the survival probability is ``exp(-H)``.

    Parameters
    ----------
    t : 1-D array of time points.
    damage : array whose first axis runs along ``t``.
    z : threshold subtracted from ``damage``; differences below zero are
        clipped to zero.
    kk : multiplier applied to the clipped difference.
    h_b : additive hazard term.

    Returns
    -------
    Array with the shape of the hazard; the entry at the first time point is
    ``exp(0) == 1``.
    """
    t = jnp.asarray(t)
    hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b

    # Cumulative trapezoidal rule in one pass.  Integrating every prefix
    # ``hazard[:i+1]`` separately inside a Python loop is unrolled by
    # ``jax.jit`` into ``len(t)`` independent integrations (quadratic work and
    # graph size); a cumulative sum of the per-interval areas gives the same
    # values up to floating-point rounding.
    dt = jnp.diff(t)
    dt = dt.reshape(dt.shape + (1,) * (hazard.ndim - 1))
    interval_area = 0.5 * dt * (hazard[1:] + hazard[:-1])

    # The integral up to the first time point is zero.
    leading_zero = jnp.zeros_like(hazard[:1], dtype=interval_area.dtype)
    H = jnp.cumsum(
        jnp.concatenate([leading_zero, interval_area], axis=0), axis=0
    )

    return jnp.exp(-H)
80
+
81
@jax.jit
def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
    """Survival probability from the running maximum of the damage.

    For every time index ``i`` the maximum over all entries of
    ``damage[:i+1]`` is taken and ``eps`` is added (``d_max``).  Where
    ``d_max > 0`` the value ``F = 1 / (1 + (d_max / alpha) ** -beta)`` is
    used, otherwise ``F = 0``.  The result is ``(1 - F) * exp(-h_b * t)``.

    Parameters
    ----------
    t : 1-D array of time points.
    damage : array whose first axis runs along ``t``.
    alpha, beta : scale and exponent used in ``F``.
    h_b : rate of the ``exp(-h_b * t)`` factor.
    eps : offset added to the running maximum.

    Returns
    -------
    Array of survival probabilities, one per time point.
    """
    n_steps = len(t)
    damage = jnp.asarray(damage)[:n_steps]

    # The running maximum is taken over *all* entries of each prefix, so any
    # trailing axes are reduced before accumulating along the time axis.
    if damage.ndim > 1:
        damage = jnp.max(damage, axis=tuple(range(1, damage.ndim)))

    # One cumulative maximum replaces a Python loop that ``jax.jit`` would
    # unroll into ``len(t)`` separate slice-and-max operations.
    d_max = jnp.squeeze(jax.lax.cummax(damage, axis=0) + eps)

    F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
    S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
    return S
87
+
88
def guts_jax(t, y, C_0, k_d, z, b, h_b):
    """Right-hand side of the damage / hazard / survival system.

    The exposure ``C_0`` is used as a constant; ``t`` is not read.

    Parameters
    ----------
    t : time (unused by the equations).
    y : triple ``(D, H, S)``; only ``D`` and ``S`` enter the equations.
    C_0 : exposure the damage relaxes towards.
    k_d : rate of the damage equation.
    z : threshold of the hazard term.
    b : multiplier of the above-threshold damage.
    h_b : additive hazard term.

    Returns
    -------
    Tuple ``(dD_dt, dH_dt, dS_dt)``.
    """
    damage, _cumulative_hazard, surviving = y

    damage_rate = k_d * (C_0 - damage)

    # Very steep arctan ramp: ~0 below the threshold, ~1 above it.
    above_threshold = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (damage - z))
    hazard_rate = (b * above_threshold * (damage - z) + h_b)

    survival_rate = -hazard_rate * surviving

    return damage_rate, hazard_rate, survival_rate
99
+
100
+
101
def RED_IT(t, y, x_in, kd):
    """Right-hand side of the one-state damage equation.

    The exposure is read from ``x_in.evaluate(t)`` and the damage relaxes
    towards it with rate ``kd``.

    NOTE(review): this definition rebinds the module-level name ``RED_IT``
    that is also imported from ``mempy.model`` at the top of the file.

    Parameters
    ----------
    t : time at which the exposure is evaluated.
    y : one-element sequence holding the current damage.
    x_in : object exposing ``evaluate(t)``.
    kd : rate constant.

    Returns
    -------
    One-element tuple ``(dD_dt,)``.
    """
    (damage,) = y
    exposure = x_in.evaluate(t)
    return (kd * (exposure - damage),)
108
+
109
+
110
def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
    """Right-hand side of the damage / hazard / survival system.

    Unlike a constant-exposure formulation, the exposure is looked up at
    every call through ``x_in.evaluate(t)``.

    Parameters
    ----------
    t : time at which the exposure is evaluated.
    y : triple ``(D, H, S)``; only ``D`` and ``S`` enter the equations.
    x_in : object exposing ``evaluate(t)``.
    k_d : rate of the damage equation.
    z : threshold of the hazard term.
    b : multiplier of the above-threshold damage.
    h_b : additive hazard term.

    Returns
    -------
    Tuple ``(dD_dt, dH_dt, dS_dt)``.
    """
    damage, _cumulative_hazard, surviving = y

    # time-dependent exposure
    exposure = x_in.evaluate(t)
    damage_rate = k_d * (exposure - damage)

    # Very steep arctan ramp: ~0 below the threshold, ~1 above it.
    above_threshold = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (damage - z))
    hazard_rate = (b * above_threshold * (damage - z) + h_b)

    survival_rate = -hazard_rate * surviving

    return damage_rate, hazard_rate, survival_rate
122
+
123
+
124
class Interpolation:
    """Small wrapper around :func:`scipy.interpolate.interp1d`.

    The interpolant is built along the first axis of ``ys`` and stored on the
    instance as ``f``.  ``method`` is passed to ``interp1d`` as ``kind`` and
    defaults to ``"previous"``.
    """

    def __init__(self, xs, ys, method="previous") -> None:
        self.f = interp1d(xs, ys, kind=method, axis=0)

    def evaluate(self, t) -> np.ndarray:
        """Return the interpolated value(s) at ``t``."""
        interpolant = self.f
        return interpolant(t)
135
+
136
+
137
class PiecewiseSymbolicSolver(PiecewiseSymbolicODESolver):
    """Piecewise symbolic solver that steps a compiled solution over ``self.x``.

    The compilation recipe consists of three entries (see
    ``compiler_recipe``): ``"F"``, ``"t_jump"`` and ``"F_piecewise"``.
    ``solve`` then evaluates the compiled ``"F_piecewise"`` function interval
    by interval, holding the exposure fixed within each interval.
    """

    # passed as ``method`` to ``Interpolation`` in ``solve``
    interpolation_type = "previous"

    def t_jump(self, func_name, compiled_functions={}):
        """Derive the time at which the solution for ``D`` equals ``z``.

        The algebraic solution for ``D`` is read from
        ``compiled_functions["F"]["algebraic_solutions"]["D"]``; its right-hand
        side is equated with the symbol ``theta["z"]`` and solved for ``t``
        with sympy.  Exactly one solution is expected.  A LaTeX rendering is
        written to ``<self.output_path>/<func_name>.tex``.

        NOTE(review): the default ``compiled_functions={}`` is a mutable
        default argument and, if actually used, fails with ``KeyError`` on the
        ``"F"`` lookup below.  The ``assert`` is skipped when Python runs with
        ``-O``.

        Returns
        -------
        Tuple ``(func, python_code)``: the simplified sympy expression and the
        ``FunctionPythonCode`` object built from it.
        """
        t, Y, Y_0, theta = self.define_symbols()

        D_t = compiled_functions["F"]["algebraic_solutions"]["D"]
        z = theta["z"]
        # D(t) == z, expanded before solving
        eq = sp.Eq(D_t.rhs, z).expand()

        t_0 = sp.solve(eq, t)

        # only a unique solution is supported
        assert len(t_0) == 1
        func = sp.simplify(t_0[0])

        python_code = FunctionPythonCode(
            func_name=func_name,
            lhs_0=("Y_0", tuple(Y_0.keys())),
            theta=("theta", tuple(theta.keys())),
            lhs=("Y",),
            rhs=(func,),
            expand_arguments=False,
            modules=("numpy","scipy"),
            docstring=""
        )

        # side effect: write the LaTeX form of the expression next to the
        # other outputs
        tex = self.to_latex(solutions=[sp.Eq(sp.Symbol(func_name), func)])
        code_file = Path(self.output_path, f"{func_name}.tex")
        with open(code_file, "w") as f:
            f.writelines(tex)

        return func, python_code

    def define_symbols(self):
        """Define the necessary symbols solely based on the function.

        Parameter names come from ``mappar`` applied to ``self.model``
        (excluding ``t``, ``dt``, ``y``, ``x_in``, ``Y`` and ``X``); state
        names come from the model's return arguments, converted with
        ``dX_dt2X``.  All symbols are declared positive and real.

        Returns
        -------
        Tuple ``(t, Y, Y_0, theta)`` where ``t`` is a sympy ``Symbol``, ``Y``
        maps state names to sympy ``Function`` objects, ``Y_0`` maps
        ``"<state>_0"`` to symbols and ``theta`` maps parameter names to
        symbols.
        """
        thetanames = mappar(
            self.model, {},
            exclude=["t", "dt", "y", "x_in", "Y", "X"],
            to="names"
        )
        ynames = [dX_dt2X(a) for a in get_return_arguments(self.model)]

        # define symbols for t, Y, Y_0 and theta
        t = sp.Symbol("t", positive=True, real=True)
        Y = {y: sp.Function(y, positive=True, real=True) for y in ynames}
        Y_0 = {
            f"{y}_0": sp.Symbol(f"{y}_0", positive=True, real=True)
            for y in ynames
        }
        theta = {p: sp.Symbol(p, positive=True, real=True) for p in thetanames}

        symbols = (t, Y, Y_0, theta)

        return symbols

    @property
    def compiler_recipe(self):
        """Mapping from compiled-function name to the callable that builds it.

        ``"F_piecewise"`` is ``self.jump_solution`` with
        ``funcnames="F t_jump"`` pre-bound.
        """
        return {
            "F":self.compile_model,
            "t_jump":self.t_jump,
            "F_piecewise":partial(self.jump_solution, funcnames="F t_jump")
        }

    def solve(self, parameters, y0, x_in):
        """Evaluate the compiled piecewise solution over ``self.x``.

        Parameters
        ----------
        parameters : passed to ``mappar`` together with ``self.model`` to
            obtain the dict of model arguments.
        y0 : mapping of state name to initial value; its key order defines the
            order of the states in the solution.
        x_in : mapping; ``x_in["exposure"]`` is used when ``"exposure"`` is
            listed in ``self.coordinates_input_vars["x_in"]``.

        Returns
        -------
        Dict with one entry per key of ``y0`` (the state evaluated at every
        point of ``self.x``) plus ``"exposure"``, the interpolated exposure at
        ``self.x``.
        """
        odeargs = mappar(
            self.model,
            parameters,
            exclude=["t", "dt", "y", "x_in"],
            to="dict"
        )

        F_piecewise = self.compiled_functions["F_piecewise"]["compiled_function"]

        # get arguments
        time = np.array(self.x)
        Y_0_values = [v for v in y0.values()]

        # handle interpolation
        if "exposure" in self.coordinates_input_vars["x_in"]:
            exposure_interpolation = Interpolation(
                xs=self.coordinates_input_vars["x_in"]["exposure"][self.x_dim],
                ys=x_in["exposure"],
                method=self.interpolation_type
            )
        else:
            # no exposure input: use the constant parameter C_0 at every point
            exposure_interpolation = Interpolation(
                xs=self.x,
                ys=np.full_like(self.x, odeargs["C_0"]), #type: ignore
                method=self.interpolation_type
            )

        # run main loop
        sol = [Y_0_values]
        for i in range(1, len(time)):
            # parse arguments
            # exposure at the *start* of the interval is used for the whole
            # interval
            C_0 = exposure_interpolation.evaluate(time[i-1])
            # NOTE(review): if "C_0" is not already a key of ``odeargs`` it is
            # appended at the end, which changes its position in the θ tuple
            # below -- confirm the compiled function expects that order.
            odeargs.update({"C_0": C_0}) # type: ignore
            dt = time[i] - time[i-1]

            # call piecewise function
            # restarted from the previous solution with the interval length as
            # time argument
            y_t = F_piecewise(
                t=dt,
                Y_0=sol[i-1],
                θ=tuple(odeargs.values()), # type: ignore
                ε=1e-14
            )

            sol.append(y_t)

        Y_t = np.array(sol)

        # transpose so that each state variable becomes one entry
        results = {k:y for k, y in zip(y0.keys(), Y_t.T)}
        results["exposure"] = exposure_interpolation.evaluate(np.array(self.x))

        return results
guts_base/plot.py ADDED
@@ -0,0 +1,162 @@
1
+ import numpy as np
2
+ import arviz as az
3
+ from matplotlib import pyplot as plt
4
+
5
+ from pymob.sim.plot import SimulationPlot
6
+ from pymob import SimulationBase
7
+
8
def plot_prior_predictions(
    sim: SimulationBase,
    data_vars=["survival"],
    title_func=lambda sp, c: f"{sp.observations.id.values[0]}"
):
    """Plot prior predictions for every id into one combined figure.

    One column is drawn per entry of ``sim.coordinates["id"]`` and one row per
    entry of ``data_vars``.  Each column is rendered by a ``SimulationPlot``
    whose axes are replaced by the axes of the shared figure.  The figure is
    saved as ``<sim.output_path>/combined_prior_predictions.png``.

    Parameters
    ----------
    sim : simulation providing observations, coordinates, config, inferer and
        output path.
    data_vars : names of the data variables, one per row.  The default list is
        never mutated here.
    title_func : forwarded to ``SimulationPlot.set_titles``.
    """
    idata = sim.inferer.prior_predictions() # type: ignore

    def plot_survival_data_probs(idata):
        # ``id_`` is the loop variable below; it is read when the function is
        # called, i.e. while the corresponding column is being plotted.
        return idata["survival"] * sim.observations.sel(id=[id_], time=0).survival

    # ``sizes`` is the mapping from dimension name to length on xarray objects
    fig, axes = plt.subplots(
        ncols=sim.observations.sizes["id"], nrows=len(data_vars),
        sharex=True, sharey="row", figsize=(30,10), squeeze=False)

    # index of the bottom row of the grid; only it keeps its x-label
    last_row = len(data_vars) - 1

    for i, id_ in enumerate(sim.coordinates["id"]):
        simplot = SimulationPlot(
            observations=sim.observations.sel(id=[id_]),
            idata=idata.sel(id=[id_]), # type: ignore
            rows=data_vars,
            coordinates=sim.coordinates,
            config=sim.config,
            obs_idata_map={
                "survival": plot_survival_data_probs
            },
            idata_groups=["prior_model_fits"], # type: ignore
        )
        # replace simplot axis
        for j, k in enumerate(simplot.rows):
            simplot.axes_map[k]["all"] = axes[j][i]

        simplot.plot_data_variables()
        simplot.set_titles(title_func)
        for j, k in enumerate(simplot.rows):
            # y-labels only on the first column, x-labels only on the last row
            if i != 0:
                simplot.axes_map[k]["all"].set_ylabel("")
            if j != last_row:
                simplot.axes_map[k]["all"].set_xlabel("")

        simplot.close()

    fig.tight_layout()
    fig.savefig(f"{sim.output_path}/combined_prior_predictions.png")
48
+
49
+
50
def plot_posterior_predictions(
    sim: SimulationBase,
    data_vars=["survival"],
    title_func=lambda sp, c: f"{sp.observations.id.values[0]}",
    groups=["posterior_model_fits", "posterior_predictive"],
):
    """Plot posterior predictions for every id into one combined figure.

    One column is drawn per entry of ``sim.coordinates["id"]`` and one row per
    entry of ``data_vars``.  Each column is rendered by a ``SimulationPlot``
    whose axes are replaced by the axes of the shared figure.  The figure is
    saved as ``<sim.output_path>/combined_posterior_predictions.png``.

    Parameters
    ----------
    sim : simulation providing observations, coordinates, config, inferer and
        output path.
    data_vars : names of the data variables, one per row.  The default list is
        never mutated here.
    title_func : forwarded to ``SimulationPlot.set_titles``.
    groups : inference-data groups passed to ``SimulationPlot`` as
        ``idata_groups``.
    """
    # ``sizes`` is the mapping from dimension name to length on xarray objects
    fig, axes = plt.subplots(
        ncols=sim.observations.sizes["id"], nrows=len(data_vars),
        sharex=True, sharey="row", figsize=(30,10), squeeze=False
    )

    def plot_survival_data_probs_and_preds(dataset):
        # ``id_`` is the loop variable below; it is read when the function is
        # called, i.e. while the corresponding column is being plotted.
        if dataset.attrs["group"] == "posterior_model_fits":
            # model fits are scaled by the observed survival at time 0
            return dataset["survival"] * sim.observations.sel(id=[id_], time=0).survival
        if dataset.attrs["group"] == "posterior_predictive":
            return dataset["survival"]

    # index of the bottom row of the grid; only it keeps its x-label
    last_row = len(data_vars) - 1

    for i, id_ in enumerate(sim.coordinates["id"]):
        simplot = SimulationPlot(
            observations=sim.observations.sel(id=[id_]),
            idata=sim.inferer.idata.sel(id=[id_]), # type: ignore
            rows=data_vars,
            coordinates=sim.coordinates,
            config=sim.config,
            obs_idata_map={
                "survival": plot_survival_data_probs_and_preds,
            },
            idata_groups=groups, # type: ignore
        )

        # replace simplot axis
        for j, k in enumerate(simplot.rows):
            simplot.axes_map[k]["all"] = axes[j][i]

        simplot.plot_data_variables()
        simplot.set_titles(title_func)
        for j, k in enumerate(simplot.rows):
            # y-labels only on the first column, x-labels only on the last row
            if i != 0:
                simplot.axes_map[k]["all"].set_ylabel("")
            if j != last_row:
                simplot.axes_map[k]["all"].set_xlabel("")

        simplot.close()

    fig.tight_layout()
    fig.savefig(f"{sim.output_path}/combined_posterior_predictions.png")
96
+
97
+
98
def plot_survival(sim, results):
    """Draw observed and simulated survival on a fresh single-axes figure.

    Observed survival is divided by ``sim.observations.subject_count`` and
    drawn in black with circle markers; ``results.survival`` is drawn as black
    lines.  Nothing is returned and the figure is not saved.
    """
    _, axis = plt.subplots(1,1)

    observations = sim.observations
    observed_fraction = observations.survival / observations.subject_count

    axis.plot(observations.time, observed_fraction.T, marker="o", color="black")
    axis.plot(results.time, results.survival.T, color="black")
103
+
104
+
105
def plot_survival_multipanel(sim, results, ncols=6, title=lambda _id: _id):
    """Plot survival per id in a grid of panels and save the figure.

    Every panel shows the mean of ``results`` over ``("chain", "draw")`` as a
    line, the 95 % HDI of ``results`` as a band, and the observed survival
    divided by its value at the first time point as markers.

    Parameters
    ----------
    sim : simulation providing ``observations`` and ``output_path``.
    results : dataset with the dimensions ``chain``, ``draw`` and ``id`` and a
        ``survival`` variable.
    ncols : number of panel columns; rows are added as needed.
    title : callable mapping an id to the panel title.

    Returns
    -------
    Path of the written file, ``<sim.output_path>/survival_multipanel.png``.
    """
    n_panels = results.sizes["id"]

    nrows = int(np.ceil(n_panels / ncols))

    # squeeze=False always yields a 2-D array of axes, so flattening also
    # works for a 1x1 grid (where a bare Axes would be returned otherwise).
    fig, axes = plt.subplots(
        nrows, ncols, sharex=True, sharey=True,
        figsize=(ncols*2+2, nrows*1.5+2), squeeze=False,
    )
    axes = axes.flatten()
    mean = results.mean(("chain", "draw"))
    hdi = az.hdi(results, 0.95)
    survival = sim.observations.survival / sim.observations.survival.isel(time=0)

    plot_kwargs = {"color": "black"}
    for _id, ax in zip(sim.observations.id.values, axes):
        ax.set_ylim(-0.05,1.05)

        # TODO: use time unit from observations (?)
        ax.set_xlabel("Time")
        ax.set_ylabel("Survival")
        ax.plot(mean.time, mean.sel(id=_id).survival.T, **plot_kwargs)
        # the transposed HDI unpacks into its two bounds
        ax.fill_between(hdi.time, *hdi.sel(id=_id).survival.T, alpha=.5, **plot_kwargs)
        ax.plot(survival.time, survival.sel(id=_id).T, ls="", marker="o", alpha=.5, **plot_kwargs)
        ax.set_title(title(_id))

    out = f"{sim.output_path}/survival_multipanel.png"
    fig.tight_layout()
    fig.savefig(out)

    return out
135
+
136
+
137
def multipanel_title(sim, _id):
    """Build a two-line panel title for the observation ``_id``.

    The first line is the string form of the selected ``exposure_path``
    values; the second line is the maximum of ``concentration_closer_x_rac``
    multiplied by 100 and rounded to three decimals, followed by `` %RAC``.
    """
    selected = sim.observations.sel(id=_id)
    path = selected.exposure_path.values
    peak = selected.concentration_closer_x_rac.max().values
    percent_rac = np.round(peak * 100, 3)
    return f"{path!s}\n{percent_rac!s} %RAC"
142
+
143
def plot_intermediate_results(sim: SimulationBase, id=0):
    """Run one evaluation of ``sim`` and plot every data variable over time.

    ``sim.dispatch()`` is called and the returned evaluator executed; its
    results are restricted to the positional index ``id`` along the ``id``
    dimension.  One axes is created per entry of
    ``sim.config.data_structure.data_variables``.

    Parameters
    ----------
    sim : simulation to evaluate.
    id : positional index passed to ``isel(id=[id])``.  The name shadows the
        builtin but is kept because callers may pass it by keyword.
    """
    e = sim.dispatch()
    e()

    results = e.results.isel(id=[id])

    datavars = sim.config.data_structure.data_variables
    batch_dim = sim.config.simulation.batch_dimension

    # squeeze=False keeps ``axes`` a 2-D array even for a single data
    # variable, where a bare (non-iterable) Axes would be returned otherwise.
    fig, axes = plt.subplots(len(datavars), 1, squeeze=False)

    for ax, dv in zip(axes[:, 0], datavars):
        res = results[dv]
        if len(res.shape) > 2:
            # more than two dimensions: draw one set of lines per slice along
            # the leading dimension
            res_ = res.transpose(..., "time", batch_dim)
            for r in res_:
                ax.plot(r.time, r, alpha=.5)
        else:
            ax.plot(results.time, res.transpose("time", batch_dim), alpha=.5)

        ax.set_ylabel(dv)