guts-base 2.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
guts_base/mod.py ADDED
@@ -0,0 +1,332 @@
1
+ from functools import partial
2
+ from pathlib import Path
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import sympy as sp
7
+ import numpy as np
8
+ from scipy.interpolate import interp1d
9
+ from guts_base.prob import likelihood
10
+
11
+ from pymob.solvers.base import mappar
12
+ from pymob.solvers.symbolic import (
13
+ PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
14
+ )
15
+
16
+ _params_info_defaults = {
17
+ "initial": 1.0,
18
+ "name": None,
19
+ "min": None,
20
+ "max": None,
21
+ "prior": None,
22
+ "dims": None,
23
+ "vary": True,
24
+ "module": None,
25
+ "unit": None,
26
+ }
27
+
28
class RED_SD:
    """Simplest GUTS model, mainly for testing.

    The right-hand side tracks two states: the damage ``D``, which relaxes
    towards the interpolated exposure at rate ``kd``, and the cumulative
    hazard ``H``, which grows by ``b * max(0, D - m) + hb``. Survival is
    derived afterwards as ``exp(-H)``.
    """
    extra_dim = "substance"
    _likelihood_func_jax = likelihood
    _it_model = False
    _params_info_defaults = _params_info_defaults

    # One entry per model parameter; all share the same wide lognormal prior.
    params_info = {
        "hb": {
            "name": "hb", "min": None, "max": None, "initial": 1.0,
            "vary": True, "prior": "lognorm(scale=1,s=5)", "dims": None,
            "unit": "1/{T}", "module": "background-mortality",
        },
        "kd": {
            "name": "kd", "min": None, "max": None, "initial": 1.0,
            "vary": True, "prior": "lognorm(scale=1,s=5)", "dims": None,
            "unit": "1/{T}", "module": "tktd",
        },
        "m": {
            "name": "m", "min": None, "max": None, "initial": 1.0,
            "vary": True, "prior": "lognorm(scale=1,s=5)", "dims": None,
            "unit": "{X}", "module": "tktd",
        },
        "b": {
            "name": "b", "min": None, "max": None, "initial": 1.0,
            "vary": True, "prior": "lognorm(scale=1,s=5)", "dims": None,
            "unit": "1/{T}/{X}", "module": "tktd",
        },
    }

    # Only "survival" is flagged as observed; D and H start at zero.
    state_variables = {
        "exposure": {"dimensions": ["id", "time", "substance"], "observed": False},
        "D": {"dimensions": ["id", "time"], "observed": False, "y0": [0.0]},
        "H": {"dimensions": ["id", "time"], "observed": False, "y0": [0.0]},
        "survival": {"dimensions": ["id", "time"], "observed": True},
    }

    @staticmethod
    def _rhs_jax(t, y, x_in, kd, b, m, hb):
        """Return the time derivatives ``(dD/dt, dH/dt)`` at time ``t``.

        ``y`` unpacks into ``(D, H)``; ``x_in`` must offer ``evaluate(t)``
        returning the exposure at ``t``.
        """
        D, H = y
        exposure_now = x_in.evaluate(t)
        damage_rate = kd * (exposure_now - D)
        # The hazard only grows from damage once D exceeds the threshold m.
        hazard_rate = b * jnp.maximum(0.0, D - m) + hb
        return damage_rate, hazard_rate

    @staticmethod
    def _solver_post_processing(results, t, interpolation):
        """Add ``survival`` and ``exposure`` to ``results`` and return it."""
        results["survival"] = jnp.exp(-results["H"])
        evaluate_at = jax.vmap(interpolation.evaluate)
        results["exposure"] = evaluate_at(t)
        return results
61
+
62
class RED_IT:
    """GUTS model variant flagged as ``_it_model = True``.

    Only the damage ``D`` is integrated by the solver. Hazard and survival
    are derived afterwards in ``_solver_post_processing`` from the running
    maximum of the damage, the parameters ``m`` and ``beta`` and the
    background hazard rate ``hb``.
    """
    extra_dim = "substance"
    _likelihood_func_jax = likelihood
    _it_model = True
    _params_info_defaults = _params_info_defaults

    params_info = {
        "hb": dict(name="hb", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="background-mortality"),
        "kd": dict(name="kd", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="tktd"),
        "m": dict(name="m", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="tktd"),
        "beta": dict(name="beta", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, module="tktd"),
    }

    # H has no initial value here because it is not integrated; it is
    # computed in the post-processing step.
    state_variables = {
        "exposure": dict(dimensions=["id", "time", "substance"], observed=False),
        "D": dict(dimensions=["id", "time"], observed=False, y0=[0.0]),
        "H": dict(dimensions=["id", "time"], observed=False),
        "survival": dict(dimensions=["id", "time"], observed=True)
    }

    @staticmethod
    def _rhs_jax(t, y, x_in, kd):
        """Return ``(dD/dt,)``: damage relaxes towards the exposure at rate ``kd``."""
        D, = y
        C = x_in.evaluate(t)

        dD_dt = kd * (C - D)

        return (dD_dt, )

    @staticmethod
    def _solver_post_processing(results, t, interpolation, m, beta, hb, eps):
        """Derive hazard, survival and exposure from the solved damage.

        For every time point the largest damage reached so far (plus
        ``eps``) is turned into ``F = 1 / (1 + (d_max / m) ** -beta)``.
        Survival is ``(1 - F) * exp(-hb * t)`` and ``H`` is its negative
        logarithm. The three entries are written into ``results``, which is
        returned.

        TODO: Try alternative formulation. This is computationally simpler and numerically
        more stable:
        log S = log 1.0 + log (1.0 - F) + log exp -hb * t = 0 + log (1.0 - F) - hb * t
        """
        damage = results["D"]
        # One damage value per time point: fold any trailing axes into a
        # maximum, mirroring the previous global max over ``D[:i + 1]``.
        per_time_max = jnp.max(
            jnp.reshape(damage, (damage.shape[0], -1)), axis=1
        )
        # Running maximum in a single pass instead of re-scanning the prefix
        # for every time point.
        running_max = jax.lax.cummax(per_time_max[:len(t)], axis=0)
        d_max = jnp.squeeze(running_max + eps)

        F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / m) ** -beta), 0)
        S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-hb * t)
        results["H"] = - jnp.log(S)
        results["survival"] = S
        results["exposure"] = jax.vmap(interpolation.evaluate)(t)

        return results
108
+
109
+
110
class RED_SD_DA:
    """GUTS model with one damage state per exposure.

    Each exposure ``i`` gets its own rate ``kd{i}`` and weight ``w{i}``. The
    hazard grows with the weighted sum of all damages above the threshold
    ``m``. ``params_info`` and ``state_variables`` are built per instance
    because their content depends on ``num_expos``.
    """
    extra_dim = "substance"
    _likelihood_func_jax = likelihood
    _it_model = False
    _params_info_defaults = _params_info_defaults
    # Kept so the class attribute still exists; instances never write to it.
    params_info = {}

    def __init__(self, num_expos = 1):
        """Create the parameter and state tables for ``num_expos`` exposures."""
        # Build a fresh dict: writing into the class-level dict would leak
        # ``kd``/``w`` entries from one instance into every other one.
        params_info = {}
        for i in range(num_expos):
            number = i + 1
            params_info[f'kd{number}'] = {
                'name':f'kd{number}', 'min':1.0e-3, 'max':1.0e3, 'initial':1.0,
                'vary':True, "dims": "substance", "module": "tktd",
                "unit": "1/{T}"
            }
            if i == 0:
                # The first weight is held fixed (vary=False).
                params_info[f'w{number}'] = {
                    'name':f'w{number}', 'initial':1.0, 'vary':False,
                    "dims": "substance", "module": "tktd",
                    "unit": "{X}/{X_i}",
                }
            else:
                params_info[f'w{number}'] = {
                    'name':f'w{number}', 'min':1.0e-3, 'max':1.0e3, 'initial':1.0,
                    'vary':True, "dims": "substance", "module": "tktd",
                    "unit": "{X}/{X_i}"
                }

        params_info.update({
            "hb": dict(name="hb", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}", module="background-mortality"),
            "m": dict(name="m", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="{X}", module="tktd"),
            "b": dict(name="b", min=None, max=None, initial=1.0, vary=True, prior="lognorm(scale=1,s=5)", dims=None, unit="1/{T}/{X}", module="tktd"),
        })
        self.params_info = params_info

        # NOTE(review): the exposure unit placeholder is "{Xi}" while the
        # weights use "{X_i}" -- confirm which spelling is intended.
        self.state_variables = {
            "exposure": dict(dimensions=["id", "time", "substance"], unit="{Xi}", observed=False),
            "D": {"dimensions": ["id", "time", "substance"], "observed": False, "y0": [0.0] * num_expos},
            "H": dict(dimensions=["id", "time"], observed=False, y0=[0.0]),
            "survival": dict(dimensions=["id", "time"], observed=True)
        }

    @staticmethod
    def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
        """Return ``(dD/dt, dH/dt)``; the hazard uses the ``w``-weighted damage sum."""
        D, H = y
        dD_dt = kd * (x_in.evaluate(t) - D)
        dH_dt = b * jnp.maximum(0.0, jnp.sum(w * D) - m) + hb
        return dD_dt, dH_dt

    @staticmethod
    def _solver_post_processing(results, t, interpolation):
        """Add ``survival = exp(-H)`` and the exposure at ``t`` to ``results``."""
        results["survival"] = jnp.exp(-results["H"])
        results["exposure"] = jax.vmap(interpolation.evaluate)(t)
        return results
163
+
164
class RED_SD_explicit_units(RED_SD):
    """``RED_SD`` with concrete units (day, mg) instead of unit placeholders."""

    def __init__(self):
        """Give this instance its own parameter table with explicit units."""
        # Copy every entry first: assigning into the inherited dicts would
        # change the units of RED_SD itself for all other users.
        self.params_info = {
            name: dict(info) for name, info in type(self).params_info.items()
        }
        self.params_info["hb"]["unit"] = "1/day"
        self.params_info["b"]["unit"] = "1/mg/day"
        self.params_info["kd"]["unit"] = "1/day"
        self.params_info["m"]["unit"] = "mg"
170
+
171
# Module-level aliases for the right-hand side and post-processing functions
# of the RED_SD and RED_SD_DA models.
red_sd, red_sd_post_processing = (
    RED_SD._rhs_jax,
    RED_SD._solver_post_processing,
)

red_sd_da, red_sd_da_post_processing = (
    RED_SD_DA._rhs_jax,
    RED_SD_DA._solver_post_processing,
)
177
+
178
+
179
def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
    """Return ``(dD/dt, dH/dt, dS/dt)`` for a constant exposure ``C_0``.

    ``y`` unpacks into damage ``D``, cumulative hazard ``H`` and survival
    ``S``. ``t`` is accepted but not used.
    """
    D, H, S = y
    damage_rate = k_d * (C_0 - D)

    # Steep arctan step: close to 0 while D < z and close to 1 once D > z.
    above_threshold = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (D - z))
    hazard_rate = (b * above_threshold * (D - z) + h_b)

    survival_rate = -hazard_rate * S

    return damage_rate, hazard_rate, survival_rate
190
+
191
def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
    """Return ``(dD/dt, dH/dt, dS/dt)`` for a time-varying exposure.

    The exposure at ``t`` is read from ``x_in.evaluate(t)``. ``y`` unpacks
    into damage ``D``, cumulative hazard ``H`` and survival ``S``.
    """
    D, H, S = y
    exposure_now = x_in.evaluate(t)
    damage_rate = k_d * (exposure_now - D)

    # Steep arctan step: close to 0 while D < z and close to 1 once D > z.
    above_threshold = 0.5 + (1 / jnp.pi) * jnp.arctan(1e16 * (D - z))
    hazard_rate = (b * above_threshold * (D - z) + h_b)

    survival_rate = -hazard_rate * S

    return damage_rate, hazard_rate, survival_rate
203
+
204
+
205
class Interpolation:
    """Thin wrapper around ``scipy.interpolate.interp1d`` with an ``evaluate`` method."""

    def __init__(self, xs, ys, method="previous") -> None:
        """Build an interpolator over ``xs``/``ys`` along axis 0 using ``method``."""
        interpolator = interp1d(x=xs, y=ys, axis=0, kind=method)
        self.f = interpolator

    def evaluate(self, t) -> np.ndarray:
        """Return the interpolated value(s) at ``t``."""
        interpolator = self.f
        return interpolator(t)
216
+
217
+
218
class PiecewiseSymbolicSolver(PiecewiseSymbolicODESolver):
    """Piecewise symbolic solver that steps between the time points of ``self.x``.

    The exposure is treated as constant within each step and is looked up
    with ``interpolation_type`` interpolation.
    """
    interpolation_type = "previous"

    def t_jump(self, func_name, compiled_functions=None):
        """Solve ``D(t) = z`` for ``t`` and wrap the result as generated code.

        Parameters
        ----------
        func_name:
            Name of the generated function; also used for the ``.tex`` file
            written to ``self.output_path``.
        compiled_functions:
            Mapping that must hold the algebraic solution of ``D`` under
            ``["F"]["algebraic_solutions"]["D"]``.

        Returns
        -------
        tuple
            The simplified sympy expression and the ``FunctionPythonCode``.
        """
        # None instead of a shared mutable ``{}`` default.
        if compiled_functions is None:
            compiled_functions = {}

        t, Y, Y_0, theta = self.define_symbols()

        D_t = compiled_functions["F"]["algebraic_solutions"]["D"]
        z = theta["z"]
        eq = sp.Eq(D_t.rhs, z).expand()

        t_0 = sp.solve(eq, t)

        # The generated code can only hold a single expression.
        assert len(t_0) == 1, (
            f"expected exactly one solution for the jump time, got {len(t_0)}"
        )
        func = sp.simplify(t_0[0])

        python_code = FunctionPythonCode(
            func_name=func_name,
            lhs_0=("Y_0", tuple(Y_0.keys())),
            theta=("theta", tuple(theta.keys())),
            lhs=("Y",),
            rhs=(func,),
            expand_arguments=False,
            modules=("numpy","scipy"),
            docstring=""
        )

        tex = self.to_latex(solutions=[sp.Eq(sp.Symbol(func_name), func)])
        code_file = Path(self.output_path, f"{func_name}.tex")
        with open(code_file, "w") as f:
            f.writelines(tex)

        return func, python_code

    def define_symbols(self):
        """Define the necessary symbols solely based on the function"""
        thetanames = mappar(
            self.model, {},
            exclude=["t", "dt", "y", "x_in", "Y", "X"],
            to="names"
        )
        ynames = [dX_dt2X(a) for a in get_return_arguments(self.model)]

        # define symbols for t, Y, Y_0 and theta
        t = sp.Symbol("t", positive=True, real=True)
        Y = {y: sp.Function(y, positive=True, real=True) for y in ynames}
        Y_0 = {
            f"{y}_0": sp.Symbol(f"{y}_0", positive=True, real=True)
            for y in ynames
        }
        theta = {p: sp.Symbol(p, positive=True, real=True) for p in thetanames}

        return (t, Y, Y_0, theta)

    @property
    def compiler_recipe(self):
        """Map each compiled-function name to the callable that builds it."""
        return {
            "F":self.compile_model,
            "t_jump":self.t_jump,
            "F_piecewise":partial(self.jump_solution, funcnames="F t_jump")
        }

    def solve(self, parameters, y0, x_in):
        """Step the compiled piecewise solution through all points of ``self.x``.

        Parameters
        ----------
        parameters:
            Parameter values, mapped onto the model arguments with ``mappar``.
        y0:
            Mapping of state name to initial value; its key order defines
            the order of the returned states.
        x_in:
            Mapping holding the ``"exposure"`` values when exposure is an
            input variable.

        Returns
        -------
        dict
            One trajectory per key of ``y0`` plus ``"exposure"``.
        """
        odeargs = mappar(
            self.model,
            parameters,
            exclude=["t", "dt", "y", "x_in"],
            to="dict"
        )

        F_piecewise = self.compiled_functions["F_piecewise"]["compiled_function"]

        # get arguments
        time_points = np.array(self.x)
        Y_0_values = list(y0.values())

        # handle interpolation
        if "exposure" in self.coordinates_input_vars["x_in"]:
            exposure_interpolation = Interpolation(
                xs=self.coordinates_input_vars["x_in"]["exposure"][self.x_dim],
                ys=x_in["exposure"],
                method=self.interpolation_type
            )
        else:
            # No exposure input: use the constant C_0 at every time point.
            exposure_interpolation = Interpolation(
                xs=self.x,
                ys=np.full_like(self.x, odeargs["C_0"]), #type: ignore
                method=self.interpolation_type
            )

        # run main loop
        sol = [Y_0_values]
        for i in range(1, len(time_points)):
            # The exposure at the start of the step is used for the whole step.
            C_0 = exposure_interpolation.evaluate(time_points[i-1])
            odeargs.update({"C_0": C_0}) # type: ignore
            dt = time_points[i] - time_points[i-1]

            # call piecewise function, starting from the previous state
            y_t = F_piecewise(
                t=dt,
                Y_0=sol[i-1],
                θ=tuple(odeargs.values()), # type: ignore
                ε=1e-14
            )

            sol.append(y_t)

        Y_t = np.array(sol)

        results = dict(zip(y0.keys(), Y_t.T))
        results["exposure"] = exposure_interpolation.evaluate(np.array(self.x))

        return results
guts_base/plot.py ADDED
@@ -0,0 +1,201 @@
1
+ from cycler import cycler
2
+ import numpy as np
3
+ import arviz as az
4
+ from matplotlib import pyplot as plt
5
+
6
+ from pymob.sim.plot import SimulationPlot
7
+ from pymob import SimulationBase
8
+
9
def plot_prior_predictions(
    sim: SimulationBase,
    data_vars=None,
    title_func=lambda sp, c: f"{sp.observations.id.values[0]}"
):
    """Plot prior predictions for every id into one combined figure.

    One column is drawn per id and one row per entry of ``data_vars``
    (default ``["survival"]``). The figure is saved as
    ``combined_prior_predictions.png`` in ``sim.output_path``.
    """
    # None instead of a shared mutable list default.
    if data_vars is None:
        data_vars = ["survival"]

    idata = sim.inferer.prior_predictions() # type: ignore

    def plot_survival_data_probs(idata):
        # ``id_`` is the loop variable below, read when this function runs.
        return idata["survival"] * sim.observations.sel(id=[id_], time=0).survival

    fig, axes = plt.subplots(
        ncols=sim.observations.sizes["id"], nrows=len(data_vars),
        sharex=True, sharey="row", figsize=(30,10), squeeze=False)
    for i, id_ in enumerate(sim.coordinates["id"]):
        simplot = SimulationPlot(
            observations=sim.observations.sel(id=[id_]),
            idata=idata.sel(id=[id_]), # type: ignore
            rows=data_vars,
            coordinates=sim.coordinates,
            config=sim.config,
            obs_idata_map={
                "survival": plot_survival_data_probs
            },
            idata_groups=["prior_model_fits"], # type: ignore
        )
        # replace simplot axis
        for j, k in enumerate(simplot.rows):
            simplot.axes_map[k]["all"] = axes[j][i]

        simplot.plot_data_variables()
        simplot.set_titles(title_func)
        last_row = len(simplot.rows) - 1
        for j, k in enumerate(simplot.rows):
            # Keep the y label on the first column only ...
            if i != 0:
                simplot.axes_map[k]["all"].set_ylabel("")
            # ... and the x label on the bottom row only (previously a
            # hard-coded row index that blanked it for other row counts).
            if j != last_row:
                simplot.axes_map[k]["all"].set_xlabel("")

        simplot.close()

    fig.tight_layout()
    fig.savefig(f"{sim.output_path}/combined_prior_predictions.png")
49
+
50
+
51
def plot_posterior_predictions(
    sim: SimulationBase,
    data_vars=None,
    title_func=lambda sp, c: f"{sp.observations.id.values[0]}",
    groups=None,
):
    """Plot posterior fits and predictions for every id into one figure.

    One column is drawn per id and one row per entry of ``data_vars``
    (default ``["survival"]``). ``groups`` defaults to
    ``["posterior_model_fits", "posterior_predictive"]``. The figure is
    saved as ``combined_posterior_predictions.png`` in ``sim.output_path``.
    """
    # None instead of shared mutable list defaults.
    if data_vars is None:
        data_vars = ["survival"]
    if groups is None:
        groups = ["posterior_model_fits", "posterior_predictive"]

    fig, axes = plt.subplots(
        ncols=sim.observations.sizes["id"], nrows=len(data_vars),
        sharex=True, sharey="row", figsize=(30,10), squeeze=False
    )

    def plot_survival_data_probs_and_preds(dataset):
        # ``id_`` is the loop variable below, read when this function runs.
        if dataset.attrs["group"] == "posterior_model_fits":
            return dataset["survival"] * sim.observations.sel(id=[id_], time=0).survival
        if dataset.attrs["group"] == "posterior_predictive":
            return dataset["survival"]

    for i, id_ in enumerate(sim.coordinates["id"]):
        simplot = SimulationPlot(
            observations=sim.observations.sel(id=[id_]),
            idata=sim.inferer.idata.sel(id=[id_]), # type: ignore
            rows=data_vars,
            coordinates=sim.coordinates,
            config=sim.config,
            obs_idata_map={
                "survival": plot_survival_data_probs_and_preds,
            },
            idata_groups=groups, # type: ignore
        )

        # replace simplot axis
        for j, k in enumerate(simplot.rows):
            simplot.axes_map[k]["all"] = axes[j][i]

        simplot.plot_data_variables()
        simplot.set_titles(title_func)
        last_row = len(simplot.rows) - 1
        for j, k in enumerate(simplot.rows):
            # Keep the y label on the first column only ...
            if i != 0:
                simplot.axes_map[k]["all"].set_ylabel("")
            # ... and the x label on the bottom row only (previously a
            # hard-coded row index that blanked it for other row counts).
            if j != last_row:
                simplot.axes_map[k]["all"].set_xlabel("")

        simplot.close()

    fig.tight_layout()
    fig.savefig(f"{sim.output_path}/combined_posterior_predictions.png")
97
+
98
+
99
def plot_survival(sim: SimulationBase, results):
    """Draw observed survival fractions (markers) and simulated survival (lines).

    Observed survival is divided by ``subject_count``. Both series are drawn
    in black on one new axis; nothing is returned or saved.
    """
    _, axis = plt.subplots(1, 1)
    observations = sim.observations
    observed_fraction = observations.survival / observations.subject_count
    axis.plot(observations.time, observed_fraction.T, marker="o", color="black")
    axis.plot(results.time, results.survival.T, color="black")
104
+
105
+
106
def plot_survival_multipanel(sim: SimulationBase, results, ncols=6, title=lambda _id: _id, filename="survival_multipanel"):
    """Plot mean survival, its 95 % HDI and the observations, one panel per id.

    ``results`` needs ``chain`` and ``draw`` dimensions, over which the mean
    is taken. Observed survival is scaled by its value at the first time
    point. The figure is saved to ``{sim.output_path}/{filename}.png`` and
    that path is returned.
    """
    n_panels = results.sizes["id"]

    nrows = int(np.ceil(n_panels / ncols))

    # squeeze=False always yields a 2-D array, so flatten() also works for a
    # 1x1 grid, where subplots would otherwise return a bare Axes.
    fig, axes = plt.subplots(
        nrows, ncols, sharex=True, sharey=True,
        figsize=(ncols*2+2, nrows*1.5+2), squeeze=False
    )
    axes = axes.flatten()
    mean = results.mean(("chain", "draw"))
    hdi = az.hdi(results, 0.95)
    survival = sim.observations.survival / sim.observations.survival.isel(time=0)

    plot_kwargs = {"color": "black"}
    for _id, ax in zip(sim.observations.id.values, axes):
        ax.set_ylim(-0.05,1.05)

        ax.set_xlabel(f"Time [{sim.config.guts_base.unit_time}]")
        ax.set_ylabel("Survival")
        ax.plot(mean.time, mean.sel(id=_id).survival.T, **plot_kwargs)
        ax.fill_between(hdi.time, *hdi.sel(id=_id).survival.T, alpha=.5, **plot_kwargs) # type: ignore
        ax.plot(survival.time, survival.sel(id=_id).T, ls="", marker="o", alpha=.5, **plot_kwargs)
        ax.set_title(title(_id))

    out = f"{sim.output_path}/{filename}.png"
    fig.tight_layout()
    fig.savefig(out)

    return out
135
+
136
def plot_exposure_multipanel(sim: SimulationBase, results, ncols=6, title=lambda _id: _id, filename="exposure_multipanel"):
    """Plot the exposure time series of every id, one panel per id.

    Each coordinate of the simulation's exposure dimension gets its own line
    style. A shared legend is placed below the panels. The figure is saved
    to ``{sim.output_path}/{filename}.png`` and that path is returned.
    """
    n_panels = results.sizes["id"]

    nrows = int(np.ceil(n_panels / ncols))

    # squeeze=False always yields a 2-D array, so flatten() also works for a
    # 1x1 grid, where subplots would otherwise return a bare Axes.
    fig, axes = plt.subplots(
        nrows, ncols, sharex=True,
        figsize=(ncols*2+2, nrows*1.5+2), squeeze=False
    )
    axes = axes.flatten()
    mean = results

    plot_kwargs = {"color": "black"}
    custom_cycler = (
        cycler(ls=["-", "--", ":", "-."])
    )

    # label -> line handle; filled while plotting and used for the legend
    labels = {}
    for _id, ax in zip(sim.observations.id.values, axes):
        ax.set_prop_cycle(custom_cycler)

        ax.set_xlabel(f"Time [{sim.config.guts_base.unit_time}]")
        ax.set_ylabel("Exposure")
        for expo in sim.coordinates[sim._exposure_dimension]: # type: ignore
            line, = ax.plot(
                mean.time, mean.sel({"id":_id, sim._exposure_dimension: expo}).exposure, # type: ignore
                **plot_kwargs, label=f"Exposure: {expo}"
            )
            labels.update({f"Exposure: {expo}": line})
        ax.set_title(title(_id))

    fig.legend(labels.values(), labels.keys(), loc='lower center', fontsize=10, frameon=False)
    out = f"{sim.output_path}/{filename}.png"
    # Leave the bottom 5 % of the figure free for the legend.
    fig.tight_layout(rect=[0,0.05,1.0,1.0], ) # type: ignore
    fig.savefig(out)

    return out
171
+
172
def multipanel_title(sim, _id):
    """Return a two-line panel title: exposure path, then the maximum %RAC.

    The percentage is the maximum of ``concentration_closer_x_rac`` for the
    selected id, times 100 and rounded to three decimals.
    """
    selected = sim.observations.sel(id=_id)
    path_label = selected.exposure_path.values
    rac_percent = np.round(
        selected.concentration_closer_x_rac.max().values * 100, 3
    )
    return f"{path_label!s}\n{rac_percent!s} %RAC"
177
+
178
def plot_intermediate_results(sim: SimulationBase, id=0):
    """Run one evaluation of ``sim`` and plot all data variables for one id.

    ``id`` is the positional index of the id to plot.
    """
    evaluator = sim.dispatch()
    evaluator()

    # Select with a list so the id dimension is kept.
    results = evaluator.results.isel(id=[id])

    plot_results(results=results, batch_dim=sim.config.simulation.batch_dimension)
186
+
187
def plot_results(results, batch_dim="id", axes=None, **plot_kwargs):
    """Plot every data variable of ``results`` over time, one axis each.

    Parameters
    ----------
    results:
        Dataset-like object with ``data_vars`` and a ``time`` coordinate.
    batch_dim:
        Name of the dimension whose entries are drawn as separate lines.
    axes:
        Iterable of axes, one per data variable. A new column of axes is
        created when omitted.
    **plot_kwargs:
        Passed on to every ``ax.plot`` call.
    """
    datavars = list(results.data_vars.keys())
    if axes is None:
        # squeeze=False keeps an array even for a single data variable, where
        # subplots would otherwise return a bare, non-iterable Axes.
        _, axes = plt.subplots(len(datavars), 1, squeeze=False)
        axes = axes[:, 0]

    for ax, dv in zip(axes, datavars):
        res = results[dv]
        if len(res.shape) > 2:
            # More than two dimensions: draw each leading slice separately.
            res_ = res.transpose(..., "time", batch_dim)
            for r in res_:
                ax.plot(r.time, r, **plot_kwargs)
        else:
            ax.plot(results.time, res.transpose("time", batch_dim), **plot_kwargs)

        ax.set_ylabel(dv)
@@ -0,0 +1,13 @@
1
+ from . import conditional_binom
2
+ from . import conditional_binom_mv
3
+ from . import predictions
4
+ from . import binom
5
+
6
+ from .binom import likelihood
7
+
8
+ from .predictions import (
9
+ survival_predictions,
10
+ posterior_predictions,
11
+ )
12
+
13
+ from .conditional_binom_mv import conditional_survival
@@ -0,0 +1,18 @@
1
+ import numpyro
2
+
3
def likelihood(theta, simulation_results, indices, observations, masks, make_predictions):
    """Register the masked binomial ``survival_obs`` sample site.

    The success probability is the simulated survival and the number of
    trials is ``observations["survivors_at_start"]``. When
    ``make_predictions`` is truthy the site is left unobserved; otherwise it
    is conditioned on ``observations["survival"]``. ``theta`` and
    ``indices`` are accepted but not used.
    """
    observed = None if make_predictions else observations["survival"]

    survival_distribution = numpyro.distributions.Binomial(
        probs=simulation_results["survival"],
        total_count=observations["survivors_at_start"],
    )

    _ = numpyro.sample(
        name="survival_obs",
        fn=survival_distribution.mask(masks["survival"]),
        obs=observed,
    )
18
+
@@ -0,0 +1,118 @@
1
+ from guts_base.data.survival import generate_survival_repeated_observations
2
+ from scipy.stats import rv_discrete
3
+ from scipy.stats._discrete_distns import _isintegral
4
+ import numpy as np
5
+ import scipy.stats._boost as _boost
6
+ from scipy import stats
7
+ from matplotlib import pyplot as plt
8
+
9
def conditional_prob_neglogexp(p, p_init=1.0, eps=1e-12):
    """Return ``p[i] / p[i-1]`` computed in log space, with ``p_init`` before ``p[0]``.

    Parameters
    ----------
    p:
        1-d array of probabilities.
    p_init:
        Probability placed before ``p[0]``; a scalar or a one-element
        sequence.
    eps:
        Lower clipping bound applied before taking the logarithm.

    Returns
    -------
    numpy.ndarray
        Array with the same length as ``p``.
    """
    # atleast_1d lets the scalar default work; np.concatenate rejects
    # zero-dimensional inputs.
    p_ = np.concatenate([np.atleast_1d(p_init), p[:]])
    # Clip so that trailing zeros in p cannot cause a division by zero
    # (log of zero).
    p_clipped = np.clip(p_, eps, 1.0)
    # convert to log scale
    neg_log_p = -np.log(p_clipped)
    # subtracting exponents is numerically more stable than dividing
    return np.exp(neg_log_p[:-1] - neg_log_p[1:])
17
+
18
+
19
def conditional_prob(p, p_init=1.0):
    """Return ``p[i] / p[i-1]`` by direct division, with ``p_init`` before ``p[0]``.

    ``p_init`` may be a scalar or a one-element sequence. No clipping is
    applied, so a zero followed by another value yields ``inf`` or ``nan``.
    """
    # atleast_1d lets the scalar default work; np.concatenate rejects
    # zero-dimensional inputs.
    p_ = np.concatenate([np.atleast_1d(p_init), p[:]])
    # divide each probability by the one before it
    return p_[1:] / p_[:-1]
23
+
24
+
25
def conditional_prob_from_neglogp(p, p_init=1.0):
    """Return ``exp(p[i-1] - p[i])`` with ``p_init`` placed before ``p[0]``.

    ``p`` is used as given inside the exponent, i.e. it is expected to hold
    negative log-probabilities. ``p_init`` may be a scalar or a one-element
    sequence.

    NOTE(review): ``p_init`` is combined with ``p`` without conversion, so
    the default of 1.0 is a negative log value (probability ``exp(-1)``),
    not a probability of one -- confirm that this is intended.
    """
    # atleast_1d accepts scalars and one-element arrays alike.
    p_ = np.concatenate([np.atleast_1d(p_init), p[:]])
    return np.exp(p_[:-1] - p_[1:])
28
+
29
+
30
+
31
class conditional_survival_gen(rv_discrete):
    """
    A scipy distribution for a conditional survival probability distribution

    Parameters
    ----------
    k: Array
        Number of repeated positive observations of a quantity that
        can only decrease. k must be monotonically decreasing (e.g. survivors)

    p: Array
        Survival function of the repeated observation.
        p must be monotonically decreasing

    n_init: int
        The starting number of positive observations (e.g. initial number of
        organisms in a survival trial)

    p_init: float
        The starting survival probability in an experiment

    Example
    -------
    Define a survival function (using a beta cdf) and use it to make
    multinomial draws that simulate survival under repeated observations

    >>> n = 100
    >>> B = stats.beta(5, 5)
    >>> p = 1 - B.cdf(np.linspace(0, 1))
    >>> s = stats.multinomial(n, p=np.diff(p)*-1).rvs()[0]
    >>> s = n - s.cumsum()

    Construct a frozen distribution

    >>> from guts_base.prob import conditional_survival
    >>> S = conditional_survival(p=p[1:], n_init=[n], p_init=[1.0], eps=[1e-12])

    Compute the pmf

    >>> S.pmf(s)

    Compute the logpmf

    >>> S.logpmf(s)

    Draw random samples

    >>> samples = S.rvs(size=(1000, 49))

    Plot the observational variation of a given survival function under
    repeated observations

    >>> plt.plot(samples.T, color="black", alpha=.02)

    """
    def __init__(self, **kwargs):
        # No custom state is needed; all keyword arguments go to rv_discrete.
        super().__init__(**kwargs)

    def _argcheck(self, p, n_init, p_init, eps):
        """Accept non-negative integral ``n_init`` and ``p`` within [0, 1]."""
        # The monotonic decrease of p is not enforced here.
        count_is_valid = (n_init >= 0) & _isintegral(n_init)
        prob_is_valid = (p >= 0) & (p <= 1)
        return count_is_valid & prob_is_valid

    def _pmf(self, x, p, n_init, p_init, eps):
        """Binomial pmf of ``x`` given the previous count and the conditional probability."""
        # Shift the counts by one step: the number at risk at each step is
        # the previous observation, starting from n_init.
        at_risk = np.concatenate([n_init[[0]], x[:-1]])
        p_conditional = conditional_prob_neglogexp(p, p_init=p_init[[0]], eps=eps[[0]])
        return _boost._binom_pdf(x, at_risk, p_conditional)

    def _rvs(self, p, n_init=10, p_init=[1.0], eps=[1e-12], size=1, random_state=None):
        """Draw survivor trajectories by sampling deaths step by step."""
        p_conditional = conditional_prob_neglogexp(p, p_init=p_init[[0]], eps=eps[[0]])

        # axis-0 is the batch dimension
        # axis-1 is the time dimension (probability)
        deaths = np.zeros(shape=(*size,))
        deaths = np.array(deaths, ndmin=2)

        for step in range(deaths.shape[1]):
            # Binomial draw of the deaths within this interval, conditional
            # on having survived until the beginning of the interval. The
            # number of trials is the initial count minus all earlier deaths.
            deaths[:, step] = random_state.binomial(
                p=1 - p_conditional[step],
                n=n_init[step]-deaths.sum(axis=1).astype(int)
            )

        return n_init-deaths.cumsum(axis=1)


conditional_survival = conditional_survival_gen(name="conditional_survival")
116
+
117
+
118
+