guts_base-2.0.0b0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guts_base/__init__.py +15 -0
- guts_base/data/__init__.py +35 -0
- guts_base/data/expydb.py +248 -0
- guts_base/data/generator.py +191 -0
- guts_base/data/openguts.py +296 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +148 -0
- guts_base/data/time_of_death.py +595 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +332 -0
- guts_base/plot.py +201 -0
- guts_base/prob/__init__.py +13 -0
- guts_base/prob/binom.py +18 -0
- guts_base/prob/conditional_binom.py +118 -0
- guts_base/prob/conditional_binom_mv.py +233 -0
- guts_base/prob/predictions.py +164 -0
- guts_base/sim/__init__.py +28 -0
- guts_base/sim/base.py +1286 -0
- guts_base/sim/config.py +170 -0
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +585 -0
- guts_base/sim/mempy.py +290 -0
- guts_base/sim/report.py +405 -0
- guts_base/sim/transformer.py +548 -0
- guts_base/sim/units.py +313 -0
- guts_base/sim/utils.py +10 -0
- guts_base-2.0.0b0.dist-info/METADATA +853 -0
- guts_base-2.0.0b0.dist-info/RECORD +32 -0
- guts_base-2.0.0b0.dist-info/WHEEL +5 -0
- guts_base-2.0.0b0.dist-info/entry_points.txt +3 -0
- guts_base-2.0.0b0.dist-info/licenses/LICENSE +674 -0
- guts_base-2.0.0b0.dist-info/top_level.txt +1 -0
guts_base/mod.py
ADDED
@@ -0,0 +1,332 @@
+from functools import partial
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import sympy as sp
+import numpy as np
+from scipy.interpolate import interp1d
+from guts_base.prob import likelihood
+
+from pymob.solvers.base import mappar
+from pymob.solvers.symbolic import (
+    PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
+)
+
+_params_info_defaults = {
+    "initial": 1.0,
+    "name": None,
+    "min": None,
+    "max": None,
+    "prior": None,
+    "dims": None,
+    "vary": True,
+    "module": None,
+    "unit": None,
+}
+
+class RED_SD:
+    """Simplest guts model, mainly for testing"""
+    extra_dim = "substance"
+    _likelihood_func_jax = likelihood
+    _it_model = False
+    _params_info_defaults = _params_info_defaults
+
+    params_info = {
+        "hb": dict(name="hb", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}", module="background-mortality"),
+        "kd": dict(name="kd", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}", module="tktd"),
+        "m": dict(name="m", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="{X}", module="tktd"),
+        "b": dict(name="b", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}/{X}", module="tktd"),
+    }
+
+    state_variables = {
+        "exposure": dict(dimensions=["id", "time", "substance"], observed=False),
+        "D": dict(dimensions=["id", "time"], observed=False, y0=[0.0]),
+        "H": dict(dimensions=["id", "time"], observed=False, y0=[0.0]),
+        "survival": dict(dimensions=["id", "time"], observed=True)
+    }
+
+    @staticmethod
+    def _rhs_jax(t, y, x_in, kd, b, m, hb):
+        D, H = y
+        dD_dt = kd * (x_in.evaluate(t) - D)
+        dH_dt = b * jnp.maximum(0.0, D - m) + hb
+        return dD_dt, dH_dt
+
+    @staticmethod
+    def _solver_post_processing(results, t, interpolation):
+        results["survival"] = jnp.exp(-results["H"])
+        results["exposure"] = jax.vmap(interpolation.evaluate)(t)
+        return results
+
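# Illustrative sketch (not part of the released file): the `x_in` argument of
# RED_SD._rhs_jax is any object exposing `.evaluate(t)` that returns the exposure
# concentration at time t; the parameter values below are hypothetical.
class _ConstantExposure:
    def evaluate(self, t):
        return 1.0  # constant exposure concentration

_dD_dt, _dH_dt = RED_SD._rhs_jax(
    t=0.0, y=(0.0, 0.0), x_in=_ConstantExposure(), kd=0.5, b=0.2, m=0.1, hb=0.01
)
# _dD_dt == 0.5 (= kd * (C - D)); _dH_dt == 0.01 (only hb, because D < m)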
+class RED_IT:
+    """Simplest guts model, mainly for testing"""
+    extra_dim = "substance"
+    _likelihood_func_jax = likelihood
+    _it_model = True
+    _params_info_defaults = _params_info_defaults
+
+    params_info = {
+        "hb": dict(name="hb", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="background-mortality"),
+        "kd": dict(name="kd", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="tktd"),
+        "m": dict(name="m", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="tktd"),
+        "beta": dict(name="beta", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="tktd"),
+    }
+
+    state_variables = {
+        "exposure": dict(dimensions=["id", "time", "substance"], observed=False),
+        "D": dict(dimensions=["id", "time"], observed=False, y0=[0.0]),
+        "H": dict(dimensions=["id", "time"], observed=False),
+        "survival": dict(dimensions=["id", "time"], observed=True)
+    }
+
+    @staticmethod
+    def _rhs_jax(t, y, x_in, kd):
+        D, = y
+        C = x_in.evaluate(t)
+
+        dD_dt = kd * (C - D)
+
+        return (dD_dt, )
+
+    @staticmethod
+    def _solver_post_processing(results, t, interpolation, m, beta, hb, eps):
+        """
+        TODO: Try alternative formulation. This is computationally simpler and numerically
+        more stable:
+        log S = log 1.0 + log (1.0 - F) + log exp -hb * t = 0 + log (1.0 - F) - hb * t
+        """
+
+        d_max = jnp.squeeze(jnp.array([jnp.max(results["D"][:i+1])+eps for i in range(len(t))]))
+        F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / m) ** -beta), 0)
+        S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-hb * t)
+        results["H"] = - jnp.log(S)
+        results["survival"] = S
+        results["exposure"] = jax.vmap(interpolation.evaluate)(t)
+
+        return results
+
+
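# Illustrative sketch (not part of the released file) of the log-scale formulation
# mentioned in the TODO of RED_IT._solver_post_processing: compute
# log S = log(1 - F) - hb * t and exponentiate once at the end, which should be
# equivalent up to floating-point error while avoiding the intermediate product.
def _log_survival_it_sketch(d_max, m, beta, hb, t):
    # log-logistic CDF of the individual tolerance threshold distribution
    F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / m) ** -beta), 0.0)
    log_S = jnp.log1p(-F) - hb * t
    return log_S  # survival = jnp.exp(log_S); cumulative hazard H = -log_S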
+class RED_SD_DA:
+    """Simplest guts model, mainly for testing"""
+    extra_dim = "substance"
+    _likelihood_func_jax = likelihood
+    _it_model = False
+    _params_info_defaults = _params_info_defaults
+    params_info = {}
+
+    def __init__(self, num_expos = 1):
+        for i in range(num_expos):
+            self.params_info[f'kd{i+1}'] = {
+                'name':f'kd{i+1}', 'min':1.0e-3, 'max':1.0e3, 'initial':1.0,
+                'vary':True, "dims": "substance", "module": "tktd",
+                "unit": "1/{T}"
+            }
+            if i == 0:
+                self.params_info[f'w{i+1}'] = {
+                    'name':f'w{i+1}', 'initial':1.0, 'vary':False,
+                    "dims": "substance", "module": "tktd",
+                    "unit": "{X}/{X_i}",
+                }
+            else:
+                self.params_info[f'w{i+1}'] = {
+                    'name':f'w{i+1}', 'min':1.0e-3, 'max':1.0e3, 'initial':1.0,
+                    'vary':True, "dims": "substance", "module": "tktd",
+                    "unit": "{X}/{X_i}"
+                }
+
+        self.params_info.update({
+            "hb": dict(name="hb", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}", module="background-mortality"),
+            "m": dict(name="m", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="{X}", module="tktd"),
+            "b": dict(name="b", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}/{X}", module="tktd"),
+        })
+
+        self.state_variables = {
+            "exposure": dict(dimensions=["id", "time", "substance"], unit="{Xi}", observed=False),
+            "D": {"dimensions": ["id", "time", "substance"], "observed": False, "y0": [0.0] * num_expos},
+            "H": dict(dimensions=["id", "time"], observed=False, y0=[0.0]),
+            "survival": dict(dimensions=["id", "time"], observed=True)
+        }
+
+    @staticmethod
+    def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
+        D, H = y
+        dD_dt = kd * (x_in.evaluate(t) - D)
+        dH_dt = b * jnp.maximum(0.0, jnp.sum(w * D) - m) + hb
+        return dD_dt, dH_dt
+
+    @staticmethod
+    def _solver_post_processing(results, t, interpolation):
+        results["survival"] = jnp.exp(-results["H"])
+        results["exposure"] = jax.vmap(interpolation.evaluate)(t)
+        return results
+
+class RED_SD_explicit_units(RED_SD):
+    def __init__(self):
+        self.params_info["hb"]["unit"] = "1/day"
+        self.params_info["b"]["unit"] = "1/mg/day"
+        self.params_info["kd"]["unit"] = "1/day"
+        self.params_info["m"]["unit"] = "mg"
+
+red_sd = RED_SD._rhs_jax
+red_sd_post_processing = RED_SD._solver_post_processing
+
+
+red_sd_da = RED_SD_DA._rhs_jax
+red_sd_da_post_processing = RED_SD_DA._solver_post_processing
+
+
+def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
+    # for constant exposure
+    D, H, S = y
+    dD_dt = k_d * (C_0 - D)
+
+    switchDS = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (D - z))
+    dH_dt = (b * switchDS * (D - z) + h_b)
+
+    dS_dt = -dH_dt * S
+
+    return dD_dt, dH_dt, dS_dt
+
+def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
+    # for time-variable exposure
+    D, H, S = y
+    C = x_in.evaluate(t)
+    dD_dt = k_d * (C - D)
+
+    switchDS = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (D - z))
+    dH_dt = (b * switchDS * (D - z) + h_b)
+
+    dS_dt = -dH_dt * S
+
+    return dD_dt, dH_dt, dS_dt
+
+
+class Interpolation:
+    def __init__(self, xs, ys, method="previous") -> None:
+        self.f = interp1d(
+            x=xs,
+            y=ys,
+            axis=0,
+            kind=method
+        )
+
+    def evaluate(self, t) -> np.ndarray:
+        return self.f(t)
+
+
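# Illustrative sketch (not part of the released file): with the default
# method="previous", Interpolation turns an exposure time series into a step
# function, i.e. the value is held constant until the next measured time point.
_expo = Interpolation(xs=np.array([0.0, 1.0, 2.0]), ys=np.array([5.0, 0.0, 3.0]))
# _expo.evaluate(0.5) -> 5.0, _expo.evaluate(1.5) -> 0.0, _expo.evaluate(2.0) -> 3.0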
+class PiecewiseSymbolicSolver(PiecewiseSymbolicODESolver):
+    interpolation_type = "previous"
+
+    def t_jump(self, func_name, compiled_functions={}):
+        t, Y, Y_0, theta = self.define_symbols()
+
+        D_t = compiled_functions["F"]["algebraic_solutions"]["D"]
+        z = theta["z"]
+        eq = sp.Eq(D_t.rhs, z).expand()
+
+        t_0 = sp.solve(eq, t)
+
+        assert len(t_0) == 1
+        func = sp.simplify(t_0[0])
+
+        python_code = FunctionPythonCode(
+            func_name=func_name,
+            lhs_0=("Y_0", tuple(Y_0.keys())),
+            theta=("theta", tuple(theta.keys())),
+            lhs=("Y",),
+            rhs=(func,),
+            expand_arguments=False,
+            modules=("numpy","scipy"),
+            docstring=""
+        )
+
+        tex = self.to_latex(solutions=[sp.Eq(sp.Symbol(func_name), func)])
+        code_file = Path(self.output_path, f"{func_name}.tex")
+        with open(code_file, "w") as f:
+            f.writelines(tex)
+
+        return func, python_code
+
+    def define_symbols(self):
+        """Define the necessary symbols solely based on the function"""
+        thetanames = mappar(
+            self.model, {},
+            exclude=["t", "dt", "y", "x_in", "Y", "X"],
+            to="names"
+        )
+        ynames = [dX_dt2X(a) for a in get_return_arguments(self.model)]
+
+        # define symbols for t, Y, Y_0 and theta
+        t = sp.Symbol("t", positive=True, real=True)
+        Y = {y: sp.Function(y, positive=True, real=True) for y in ynames}
+        Y_0 = {
+            f"{y}_0": sp.Symbol(f"{y}_0", positive=True, real=True)
+            for y in ynames
+        }
+        theta = {p: sp.Symbol(p, positive=True, real=True) for p in thetanames}
+
+        symbols = (t, Y, Y_0, theta)
+
+        return symbols
+
+    @property
+    def compiler_recipe(self):
+        return {
+            "F":self.compile_model,
+            "t_jump":self.t_jump,
+            "F_piecewise":partial(self.jump_solution, funcnames="F t_jump")
+        }
+
+    def solve(self, parameters, y0, x_in):
+        odeargs = mappar(
+            self.model,
+            parameters,
+            exclude=["t", "dt", "y", "x_in"],
+            to="dict"
+        )
+
+        F_piecewise = self.compiled_functions["F_piecewise"]["compiled_function"]
+
+        # get arguments
+        time = np.array(self.x)
+        Y_0_values = [v for v in y0.values()]
+
+        # handle interpolation
+        if "exposure" in self.coordinates_input_vars["x_in"]:
+            exposure_interpolation = Interpolation(
+                xs=self.coordinates_input_vars["x_in"]["exposure"][self.x_dim],
+                ys=x_in["exposure"],
+                method=self.interpolation_type
+            )
+        else:
+            exposure_interpolation = Interpolation(
+                xs=self.x,
+                ys=np.full_like(self.x, odeargs["C_0"]), #type: ignore
+                method=self.interpolation_type
+            )
+
+        # run main loop
+        sol = [Y_0_values]
+        for i in range(1, len(time)):
+            # parse arguments
+            C_0 = exposure_interpolation.evaluate(time[i-1])
+            odeargs.update({"C_0": C_0}) # type: ignore
+            dt = time[i] - time[i-1]
+
+            # call piecewise function
+            y_t = F_piecewise(
+                t=dt,
+                Y_0=sol[i-1],
+                θ=tuple(odeargs.values()), # type: ignore
+                ε=1e-14
+            )
+
+            sol.append(y_t)
+
+        Y_t = np.array(sol)
+
+        results = {k:y for k, y in zip(y0.keys(), Y_t.T)}
+        results["exposure"] = exposure_interpolation.evaluate(np.array(self.x))
+
+        return results
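A minimal usage sketch for the plain right-hand-side functions above (illustrative only, with hypothetical parameter values; the package normally wires these functions into pymob solvers rather than calling them directly, and the import assumes guts_base is installed): integrating guts_constant_exposure with a naive forward-Euler loop yields damage, hazard, and survival trajectories.

    import numpy as np
    from guts_base.mod import guts_constant_exposure

    t = np.linspace(0.0, 10.0, 101)
    y = (0.0, 0.0, 1.0)  # initial damage D, cumulative hazard H, survival S
    trajectory = [y]
    for dt in np.diff(t):
        dy = guts_constant_exposure(None, y, C_0=1.0, k_d=0.5, z=0.2, b=0.3, h_b=0.01)
        y = tuple(float(yi) + float(dyi) * dt for yi, dyi in zip(y, dy))
        trajectory.append(y)
    D, H, S = np.array(trajectory).T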
guts_base/plot.py
ADDED
@@ -0,0 +1,201 @@
+from cycler import cycler
+import numpy as np
+import arviz as az
+from matplotlib import pyplot as plt
+
+from pymob.sim.plot import SimulationPlot
+from pymob import SimulationBase
+
+def plot_prior_predictions(
+    sim: SimulationBase,
+    data_vars=["survival"],
+    title_func=lambda sp, c: f"{sp.observations.id.values[0]}"
+):
+    idata = sim.inferer.prior_predictions() # type: ignore
+    def plot_survival_data_probs(idata):
+        return idata["survival"] * sim.observations.sel(id=[id_], time=0).survival
+
+    fig, axes = plt.subplots(
+        ncols=sim.observations.dims["id"], nrows=len(data_vars),
+        sharex=True, sharey="row", figsize=(30,10), squeeze=False)
+    for i, id_ in enumerate(sim.coordinates["id"]):
+        simplot = SimulationPlot(
+            observations=sim.observations.sel(id=[id_]),
+            idata=idata.sel(id=[id_]), # type: ignore
+            rows=data_vars,
+            coordinates=sim.coordinates,
+            config=sim.config,
+            obs_idata_map={
+                "survival": plot_survival_data_probs
+            },
+            idata_groups=["prior_model_fits"], # type: ignore
+        )
+        # replace simplot axis
+        for j, k in enumerate(simplot.rows):
+            simplot.axes_map[k]["all"] = axes[j][i]
+
+        simplot.plot_data_variables()
+        simplot.set_titles(title_func)
+        for j, k in enumerate(simplot.rows):
+            if i != 0:
+                simplot.axes_map[k]["all"].set_ylabel("")
+            if j != 3:
+                simplot.axes_map[k]["all"].set_xlabel("")
+
+        simplot.close()
+
+    fig.tight_layout()
+    fig.savefig(f"{sim.output_path}/combined_prior_predictions.png")
+
+
+def plot_posterior_predictions(
+    sim: SimulationBase,
+    data_vars=["survival"],
+    title_func=lambda sp, c: f"{sp.observations.id.values[0]}",
+    groups=["posterior_model_fits", "posterior_predictive"],
+):
+    fig, axes = plt.subplots(
+        ncols=sim.observations.dims["id"], nrows=len(data_vars),
+        sharex=True, sharey="row", figsize=(30,10), squeeze=False
+    )
+
+    def plot_survival_data_probs_and_preds(dataset):
+        if dataset.attrs["group"] == "posterior_model_fits":
+            return dataset["survival"] * sim.observations.sel(id=[id_], time=0).survival
+        if dataset.attrs["group"] == "posterior_predictive":
+            return dataset["survival"]
+
+    for i, id_ in enumerate(sim.coordinates["id"]):
+        simplot = SimulationPlot(
+            observations=sim.observations.sel(id=[id_]),
+            idata=sim.inferer.idata.sel(id=[id_]), # type: ignore
+            rows=data_vars,
+            coordinates=sim.coordinates,
+            config=sim.config,
+            obs_idata_map={
+                "survival": plot_survival_data_probs_and_preds,
+            },
+            idata_groups=groups, # type: ignore
+        )
+
+        # replace simplot axis
+        for j, k in enumerate(simplot.rows):
+            simplot.axes_map[k]["all"] = axes[j][i]
+
+        simplot.plot_data_variables()
+        simplot.set_titles(title_func)
+        for j, k in enumerate(simplot.rows):
+            if i != 0:
+                simplot.axes_map[k]["all"].set_ylabel("")
+            if j != 3:
+                simplot.axes_map[k]["all"].set_xlabel("")
+
+        simplot.close()
+
+    fig.tight_layout()
+    fig.savefig(f"{sim.output_path}/combined_posterior_predictions.png")
+
+
+def plot_survival(sim: SimulationBase, results):
+    fig, ax = plt.subplots(1,1)
+    obs = sim.observations.survival / sim.observations.subject_count
+    ax.plot(sim.observations.time, obs.T, marker="o", color="black")
+    ax.plot(results.time, results.survival.T, color="black")
+
+
+def plot_survival_multipanel(sim: SimulationBase, results, ncols=6, title=lambda _id: _id, filename="survival_multipanel"):
+
+    n_panels = results.sizes["id"]
+
+    nrows = int(np.ceil(n_panels / ncols))
+
+    fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=(ncols*2+2, nrows*1.5+2))
+    axes = axes.flatten()
+    mean = results.mean(("chain", "draw"))
+    hdi = az.hdi(results, 0.95)
+    survival = sim.observations.survival / sim.observations.survival.isel(time=0)
+
+    plot_kwargs = {"color": "black"}
+    # param_cycler = plt.rcParams['axes.prop_cycle']
+    for _id, ax in zip(sim.observations.id.values, axes):
+        ax.set_ylim(-0.05,1.05)
+
+        ax.set_xlabel(f"Time [{sim.config.guts_base.unit_time}]")
+        ax.set_ylabel("Survival")
+        ax.plot(mean.time, mean.sel(id=_id).survival.T, **plot_kwargs)
+        ax.fill_between(hdi.time, *hdi.sel(id=_id).survival.T, alpha=.5, **plot_kwargs) # type: ignore
+        ax.plot(survival.time, survival.sel(id=_id).T, ls="", marker="o", alpha=.5, **plot_kwargs)
+        ax.set_title(title(_id))
+
+    out = f"{sim.output_path}/{filename}.png"
+    fig.tight_layout()
+    fig.savefig(out)
+
+    return out
+
+def plot_exposure_multipanel(sim: SimulationBase, results, ncols=6, title=lambda _id: _id, filename="exposure_multipanel"):
+
+    n_panels = results.sizes["id"]
+
+    nrows = int(np.ceil(n_panels / ncols))
+
+    fig, axes = plt.subplots(nrows, ncols, sharex=True, figsize=(ncols*2+2, nrows*1.5+2))
+    axes = axes.flatten()
+    mean = results
+
+    plot_kwargs = {"color": "black"}
+    custom_cycler = (
+        cycler(ls=["-", "--", ":", "-."])
+    )
+
+    labels = {}
+    for _id, ax in zip(sim.observations.id.values, axes):
+        ax.set_prop_cycle(custom_cycler)
+
+        ax.set_xlabel(f"Time [{sim.config.guts_base.unit_time}]")
+        ax.set_ylabel("Exposure")
+        for expo in sim.coordinates[sim._exposure_dimension]: # type: ignore
+            line, = ax.plot(
+                mean.time, mean.sel({"id":_id, sim._exposure_dimension: expo}).exposure, # type: ignore
+                **plot_kwargs, label=f"Exposure: {expo}"
+            )
+            labels.update({f"Exposure: {expo}": line})
+        ax.set_title(title(_id))
+
+    fig.legend(labels.values(), labels.keys(), loc='lower center', fontsize=10, frameon=False)
+    out = f"{sim.output_path}/{filename}.png"
+    fig.tight_layout(rect=[0,0.05,1.0,1.0], ) # type: ignore
+    fig.savefig(out)
+
+    return out
+
+def multipanel_title(sim, _id):
+    oid = sim.observations.sel(id=_id)
+    exposure_path = oid.exposure_path.values
+    rac = np.round(oid.concentration_closer_x_rac.max().values * 100, 3)
+    return "{ep}\n{c} %RAC".format(ep=str(exposure_path), c=str(rac))
+
+def plot_intermediate_results(sim: SimulationBase, id=0):
+    e = sim.dispatch()
+    e()
+    e.results
+
+    results = e.results.isel(id=[id])
+
+    plot_results(results=results, batch_dim=sim.config.simulation.batch_dimension)
+
+def plot_results(results, batch_dim="id", axes=None, **plot_kwargs):
+    datavars = list(results.data_vars.keys())
+    if axes is None:
+        fig, axes = plt.subplots(len(datavars), 1)
+
+    for ax, dv in zip(axes, datavars):
+        res = results[dv]
+        if len(res.shape) > 2:
+            res_ = res.transpose(..., "time", batch_dim)
+            for r in res_:
+                ax.plot(r.time, r, **plot_kwargs)
+        else:
+            ax.plot(results.time, res.transpose("time", batch_dim), **plot_kwargs)
+
+        ax.set_ylabel(dv)
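A small usage sketch for plot_results (illustrative only, with made-up data; it assumes guts_base, xarray, and matplotlib are installed): any xarray Dataset with a "time" coordinate and a batch dimension can be passed directly.

    import numpy as np
    import xarray as xr
    from guts_base.plot import plot_results

    results = xr.Dataset(
        data_vars={
            "D": (("id", "time"), np.random.rand(3, 10)),
            "survival": (("id", "time"), np.linspace(1.0, 0.5, 10)[None, :].repeat(3, axis=0)),
        },
        coords={"id": ["a", "b", "c"], "time": np.arange(10)},
    )
    plot_results(results, batch_dim="id")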
guts_base/prob/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from . import conditional_binom
+from . import conditional_binom_mv
+from . import predictions
+from . import binom
+
+from .binom import likelihood
+
+from .predictions import (
+    survival_predictions,
+    posterior_predictions,
+)
+
+from .conditional_binom_mv import conditional_survival
guts_base/prob/binom.py
ADDED
@@ -0,0 +1,18 @@
+import numpyro
+
+def likelihood(theta, simulation_results, indices, observations, masks, make_predictions):
+    """Uses lookup and error model from the local function context"""
+    if make_predictions:
+        obs = None
+    else:
+        obs = observations["survival"]
+
+    _ = numpyro.sample(
+        name="survival" + "_obs",
+        fn=numpyro.distributions.Binomial(
+            probs=simulation_results["survival"],
+            total_count=observations["survivors_at_start"],
+        ).mask(masks["survival"]),
+        obs=obs
+    )
+
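The error model above treats the observed survivor counts at each time point as binomial draws with the simulated survival probability and the number of animals alive at the start of the trial. A minimal sketch of the same log-likelihood computed outside numpyro (illustrative only, with made-up numbers):

    import numpy as np
    from scipy import stats

    predicted_survival = np.array([1.0, 0.9, 0.7, 0.5])  # simulated S(t)
    observed_survivors = np.array([20, 18, 15, 9])        # counted survivors
    survivors_at_start = 20
    log_lik = stats.binom.logpmf(observed_survivors, survivors_at_start, predicted_survival).sum()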
@@ -0,0 +1,118 @@
|
+from guts_base.data.survival import generate_survival_repeated_observations
+from scipy.stats import rv_discrete
+from scipy.stats._discrete_distns import _isintegral
+import numpy as np
+import scipy.stats._boost as _boost
+from scipy import stats
+from matplotlib import pyplot as plt
+
+def conditional_prob_neglogexp(p, p_init=1.0, eps=1e-12):
+    p_ = np.concatenate([p_init, p[:]])
+    # p needs to be clipped so that zero division does not occur if the last p values are zero
+    p_clipped = np.clip(p_, eps, 1.0)
+    # convert to logscale
+    neg_log_p = -np.log(p_clipped)
+    # exponent subtraction is numerically more stable than division
+    return np.exp(neg_log_p[:-1] - neg_log_p[1:])
+
+
+def conditional_prob(p, p_init=1.0):
+    p_ = np.concatenate([p_init, p[:]])
+    # divide the later probability by the previous probability
+    return p_[1:] / p_[:-1]
+
+
+def conditional_prob_from_neglogp(p, p_init=1.0):
+    p_ = np.concatenate([[p_init], p[:]])
+    return np.exp(p_[:-1] - p_[1:])
+
+
+
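# Illustrative sketch (not part of the released file): for a monotonically decreasing
# survival function S, both helpers return the interval probabilities S(t_i) / S(t_{i-1});
# the neglogexp variant only adds clipping so trailing zeros cannot cause division by zero.
_S = np.array([0.8, 0.5, 0.25])
_cond = conditional_prob(_S, p_init=[1.0])                   # array([0.8, 0.625, 0.5])
_cond_stable = conditional_prob_neglogexp(_S, p_init=[1.0])  # array([0.8, 0.625, 0.5])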
+class conditional_survival_gen(rv_discrete):
+    """
+    A scipy distribution for a conditional survival probability distribution
+
+    Parameters
+    ----------
+    k: Array
+        Number of repeated positive observations of a quantity that
+        can only decrease. k must be monotonically decreasing (e.g. survivors)
+
+    p: Array
+        survival function of the repeated observation.
+        p must be monotonically decreasing
+
+    n_init: int
+        The starting number of positive observations (e.g. initial number of organisms
+        in a survival trial)
+
+    p_init: float
+        The starting survival probability in an experiment
+
+    Example
+    -------
+    Define a survival function (using a beta cdf) and use it to make multinomial draws
+    to simulate survival counts from repeated observations
+    >>> n = 100
+    >>> B = stats.beta(5, 5)
+    >>> p = 1 - B.cdf(np.linspace(0, 1))
+    >>> s = stats.multinomial(n, p=np.diff(p)*-1).rvs()[0]
+    >>> s = n - s.cumsum()
+
+    Construct a frozen distribution
+    >>> from guts_base.prob import conditional_survival
+    >>> S = conditional_survival(p=p[1:], n_init=[n], p_init=[1.0], eps=[1e-12])
+
+    Compute the pmf
+    >>> S.pmf(s)
+
+    Compute the logpmf
+    >>> S.logpmf(s)
+
+    Draw random samples
+    >>> samples = S.rvs(size=(1000, 49))
+
+    Plot the observational variation of a given survival function under repeated
+    observations
+    >>> plt.plot(samples.T, color="black", alpha=.02)
+
+
+    """
+    def __init__(self, **kwargs):
+        # Initialize your custom parameters here
+        super().__init__(**kwargs)
+        # Set up any custom state needed for sampling
+
+    def _argcheck(self, p, n_init, p_init, eps):
+        return (n_init >= 0) & _isintegral(n_init) & (p >= 0) & (p <= 1)# & (np.diff(np.concatenate([p_init, p])) <= 0)
+
+    def _pmf(self, x, p, n_init, p_init, eps):
+        # nan filling is not necessary, because nans are thrown out; this shifts the
+        # p vector to where it belongs
+        n_ = np.concatenate([n_init[[0]], x[:-1]])
+        p_conditional = conditional_prob_neglogexp(p, p_init=p_init[[0]], eps=eps[[0]])
+        return _boost._binom_pdf(x, n_, p_conditional)
+
+    def _rvs(self, p, n_init=10, p_init=[1.0], eps=[1e-12], size=1, random_state=None):
+        p_conditional = conditional_prob_neglogexp(p, p_init=p_init[[0]], eps=eps[[0]])
+
+        # axis-0 is the batch dimension
+        # axis-1 is the time dimension (probability)
+        L = np.zeros(shape=(*size,))
+        L = np.array(L, ndmin=2)
+
+        for i in range(L.shape[1]):
+            # calculate the binomial response of the conditional survival
+            # i.e. the probability to die within an interval conditional on
+            # having survived until the beginning of that interval
+            L[:, i] = random_state.binomial(
+                p=1 - p_conditional[i],
+                n=n_init[i]-L.sum(axis=1).astype(int)
+            )
+
+        return n_init-L.cumsum(axis=1)
+
+conditional_survival = conditional_survival_gen(name="conditional_survival", )
+
+
+