guts-base 0.8.6__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

Files changed (40)
  1. {guts_base-0.8.6 → guts_base-1.0.0}/PKG-INFO +2 -3
  2. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/__init__.py +2 -1
  3. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/__init__.py +1 -1
  4. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/generator.py +2 -1
  5. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/survival.py +6 -0
  6. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/mod.py +24 -83
  7. guts_base-1.0.0/guts_base/prob.py +160 -0
  8. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/sim/__init__.py +10 -1
  9. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/sim/base.py +285 -75
  10. guts_base-1.0.0/guts_base/sim/constructors.py +31 -0
  11. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/sim/ecx.py +168 -58
  12. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/sim/mempy.py +85 -70
  13. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/sim/report.py +0 -1
  14. guts_base-1.0.0/guts_base/sim/utils.py +10 -0
  15. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base.egg-info/PKG-INFO +2 -3
  16. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base.egg-info/SOURCES.txt +2 -2
  17. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base.egg-info/requires.txt +1 -2
  18. {guts_base-0.8.6 → guts_base-1.0.0}/pyproject.toml +3 -4
  19. guts_base-1.0.0/tests/test_ecx.py +38 -0
  20. {guts_base-0.8.6 → guts_base-1.0.0}/tests/test_from_pymob.py +1 -8
  21. {guts_base-0.8.6 → guts_base-1.0.0}/tests/test_simulations.py +3 -29
  22. guts_base-0.8.6/guts_base/prob.py +0 -412
  23. guts_base-0.8.6/guts_base/sim.py +0 -0
  24. guts_base-0.8.6/tests/test_ecx.py +0 -62
  25. guts_base-0.8.6/tests/test_simulations_from_mempy.py +0 -119
  26. {guts_base-0.8.6 → guts_base-1.0.0}/LICENSE +0 -0
  27. {guts_base-0.8.6 → guts_base-1.0.0}/README.md +0 -0
  28. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/expydb.py +0 -0
  29. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/openguts.py +0 -0
  30. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/preprocessing.py +0 -0
  31. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/time_of_death.py +0 -0
  32. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/data/utils.py +0 -0
  33. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base/plot.py +0 -0
  34. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base.egg-info/dependency_links.txt +0 -0
  35. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base.egg-info/entry_points.txt +0 -0
  36. {guts_base-0.8.6 → guts_base-1.0.0}/guts_base.egg-info/top_level.txt +0 -0
  37. {guts_base-0.8.6 → guts_base-1.0.0}/setup.cfg +0 -0
  38. {guts_base-0.8.6 → guts_base-1.0.0}/tests/test_data_import.py +0 -0
  39. {guts_base-0.8.6 → guts_base-1.0.0}/tests/test_scripted_simulations.py +0 -0
  40. {guts_base-0.8.6 → guts_base-1.0.0}/tests/test_symbolic_solve.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 0.8.6
3
+ Version: 1.0.0
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -692,8 +692,7 @@ License-File: LICENSE
692
692
  Requires-Dist: openpyxl>=3.1.3
693
693
  Requires-Dist: Bottleneck>=1.5.0
694
694
  Requires-Dist: expydb>=0.6.0
695
- Requires-Dist: mempyguts>=1.5.0
696
- Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.4.1
695
+ Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
697
696
  Provides-Extra: dev
698
697
  Requires-Dist: pytest>=7.3; extra == "dev"
699
698
  Requires-Dist: bumpver; extra == "dev"
@@ -4,11 +4,12 @@ from . import data
4
4
  from . import prob
5
5
  from . import plot
6
6
 
7
- __version__ = "0.8.6"
7
+ __version__ = "1.0.0"
8
8
 
9
9
  from .sim import (
10
10
  GutsBase,
11
11
  PymobSimulator,
12
12
  ECxEstimator,
13
13
  LPxEstimator,
14
+ GutsBaseError,
14
15
  )
@@ -21,7 +21,7 @@ from .survival import (
21
21
  generate_survival_repeated_observations
22
22
  )
23
23
 
24
- from .generator import create_artificial_data, design_exposure_scenario
24
+ from .generator import create_artificial_data, design_exposure_scenario, ExposureDataDict
25
25
 
26
26
  from .expydb import (
27
27
  to_dataset,
@@ -11,7 +11,8 @@ def create_artificial_data(
11
11
  t_max,
12
12
  dt,
13
13
  exposure_paths=["oral", "topical", "contact"],
14
- intensity=[0.1, 0.5, 0.05]
14
+ intensity=[0.1, 0.5, 0.05],
15
+ seed=1,
15
16
  ):
16
17
  rng = np.random.default_rng(1)
17
18
  time = np.arange(0, t_max, step=dt) # daily time resolution
@@ -40,6 +40,12 @@ def prepare_survival_data_for_conditional_binomial(observations: xr.Dataset) ->
40
40
  nsurv.isel(time=list(range(0, len(nsurv.time)-1))).values
41
41
  ]).astype(int))})
42
42
 
43
+ observations = observations.assign_coords({
44
+ "survivors_at_start": (("id", "time"), np.broadcast_to(
45
+ nsurv.isel(time=0).values.reshape(-1,1),
46
+ shape=nsurv.shape
47
+ ).astype(int))})
48
+
43
49
  return observations
44
50
 
45
51
 
@@ -1,10 +1,3 @@
1
- """
2
- GUTS models
3
-
4
- TODO: Import guts models from mempy and update to work well with jax
5
- TODO: Based on this implement the bufferguts model in mempy and import in the
6
- bufferguts case-study
7
- """
8
1
  from functools import partial
9
2
  from pathlib import Path
10
3
 
@@ -19,79 +12,37 @@ from pymob.solvers.symbolic import (
19
12
  PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
20
13
  )
21
14
 
22
- from mempy.model import (
23
- RED_IT,
24
- RED_SD,
25
- RED_IT_DA,
26
- RED_SD_DA,
27
- RED_IT_IA,
28
- RED_SD_IA,
29
- BufferGUTS_IT,
30
- BufferGUTS_IT_CA,
31
- BufferGUTS_IT_DA
32
- )
15
+ class RED_SD:
16
+ """Simplest guts model, mainly for testing"""
17
+ @staticmethod
18
+ def _rhs_jax(t, y, x_in, kd, b, m, hb):
19
+ D, H = y
20
+ dD_dt = kd * (x_in.evaluate(t) - D)
21
+ dH_dt = b * jnp.maximum(0.0, D - m) + hb
22
+ return dD_dt, dH_dt
23
+
24
+ @staticmethod
25
+ def _solver_post_processing(results, t, interpolation):
26
+ results["survival"] = jnp.exp(-results["H"])
27
+ results["exposure"] = jax.vmap(interpolation.evaluate)(t)
28
+ return results
33
29
 
34
30
  red_sd = RED_SD._rhs_jax
35
31
  red_sd_post_processing = RED_SD._solver_post_processing
36
32
 
37
- red_it = RED_IT._rhs_jax
38
- red_it_post_processing = RED_IT._solver_post_processing
33
+ class RED_SD_DA(RED_SD):
34
+ @staticmethod
35
+ def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
36
+ D, H = y
37
+ dD_dt = kd * (x_in.evaluate(t) - D)
38
+ dH_dt = b * jnp.maximum(0.0, jnp.sum(w * D) - m) + hb
39
+ return dD_dt, dH_dt
39
40
 
40
- red_sd_ia = RED_SD_IA._rhs_jax
41
- red_sd_ia_post_processing = RED_SD_IA._solver_post_processing
41
+ red_sd_da = RED_SD_DA._rhs_jax
42
+ red_sd_da_post_processing = RED_SD_DA._solver_post_processing
42
43
 
43
44
 
44
- def p_survival(results, t, interpolation, z, k_k, h_b):
45
- """Computes the stochastic death survival probability after computing
46
-
47
- """
48
- # calculate survival
49
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
50
- p_surv = survival_jax(t, results["D"], z, k_k, h_b)
51
- results["survival"] = p_surv
52
- results["lethality"] = 1 - p_surv
53
- return results
54
-
55
- def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
56
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
57
- p_surv = survival_IT_jax(t, results["D"], alpha, beta, h_b, eps)
58
- results["survival"] = p_surv
59
- results["H"] = - jnp.log(p_surv)
60
- return results
61
-
62
- def post_exposure(results, t, interpolation):
63
- results["survival"] = jnp.exp(-results["H"])
64
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
65
- return results
66
-
67
-
68
- def no_post_processing(results):
69
- return results
70
-
71
-
72
- @jax.jit
73
- def survival_jax(t, damage, z, kk, h_b):
74
- """
75
- survival probability derived from hazard
76
- first calculate cumulative Hazard by integrating hazard cumulatively over t
77
- then calculate the resulting survival probability
78
- It was checked that `survival_jax` behaves exactly the same as `survival`
79
- """
80
- hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
81
- H = jnp.array([jax.scipy.integrate.trapezoid(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
82
- # H = jnp.array([jnp.trapz(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
83
- S = jnp.exp(-H)
84
-
85
- return S
86
-
87
- @jax.jit
88
- def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
89
- d_max = jnp.squeeze(jnp.array([jnp.max(damage[:i+1])+eps for i in range(len(t))]))
90
- F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
91
- S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
92
- return S
93
-
94
- def guts_jax(t, y, C_0, k_d, z, b, h_b):
45
+ def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
95
46
  # for constant exposure
96
47
  D, H, S = y
97
48
  dD_dt = k_d * (C_0 - D)
@@ -103,16 +54,6 @@ def guts_jax(t, y, C_0, k_d, z, b, h_b):
103
54
 
104
55
  return dD_dt, dH_dt, dS_dt
105
56
 
106
-
107
- def RED_IT(t, y, x_in, kd):
108
- D, = y
109
- C = x_in.evaluate(t)
110
-
111
- dD_dt = kd * (C - D)
112
-
113
- return (dD_dt, )
114
-
115
-
116
57
  def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
117
58
  # for variable exposure
118
59
  D, H, S = y
@@ -0,0 +1,160 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from numpyro.infer import Predictive
6
+ from typing import Literal
7
+
8
+
9
+ def survival_predictions(
10
+ probs, n_trials,
11
+ eps: float = 0.0,
12
+ seed=1,
13
+ mode: Literal["survival", "lethality", "deaths"] = "survival"
14
+ ):
15
+ """Generate predictions for survival based on a multinomial survival distribution
16
+
17
+ Parameters
18
+ ----------
19
+ probs : ArrayLike
20
+ 2D Array denoting the survival probabilities at each time point
21
+ per experiment
22
+ dims=(experiment, time)
23
+ n_trials : ArrayLike
24
+ 1D Array denoting the number of organisms at the beginning in each experiment
25
+ dims = (experiment,)
26
+ seed : int, optional
27
+ Seed for the random number generator, by default 1
28
+ mode : str, optional
29
+ How should the random draws be returned?
30
+ - survival: Decreasing from n_trials to 0
31
+ - lethality: Increasing from 0 to n_trials
32
+ - deaths: Between 0 and n_trials in each interval. Summing to n_trials
33
+ """
34
+
35
+ def survival_to_death_probs(pr_survival):
36
+ # truncate here, because numeric errors below the solver tolerance can
37
+ # lead to negative values in the difference. This needs to be cured here
38
+ pr_survival_ = np.trunc(pr_survival / eps) * eps
39
+ pr_death = pr_survival_[:-1] - pr_survival_[1:]
40
+
41
+ pr_death = np.concatenate([
42
+ # concatenate a zero at the beginning in order to "simulate" no
43
+ # deaths at T = 0
44
+ jnp.zeros((1,)),
45
+ # Delta S
46
+ pr_death,
47
+ # The remaining mortality as T -> infinity
48
+ jnp.ones((1,))-pr_death.sum()
49
+ ])
50
+
51
+ # make sure the vector is not zero or 1 (this is always problematic for
52
+ # probabilities) and make sure the vector sums to 1
53
+ pr_death = np.clip(pr_death, eps, 1-eps)
54
+ pr_death = pr_death / pr_death.sum()
55
+ return pr_death
56
+
57
+ rng = np.random.default_rng(seed)
58
+ deaths = jnp.array(list(map(
59
+ lambda n, pr_survival: rng.multinomial(
60
+ n=n, pvals=survival_to_death_probs(pr_survival)
61
+ ),
62
+ n_trials,
63
+ probs
64
+ )))
65
+
66
+ # remove the last observations to trim off the simulated unobserved mortality
67
+ deaths = deaths[:, :-1]
68
+
69
+ if mode == "deaths":
70
+ return deaths
71
+ elif mode == "lethality":
72
+ return deaths.cumsum(axis=1)
73
+ elif mode == "survival":
74
+ return np.expand_dims(n_trials, axis=1) - deaths.cumsum(axis=1)
75
+ else:
76
+ raise NotImplementedError(
77
+ f"Mode {mode} is not implemented."+
78
+ "Use one of 'survival', 'lethality', or 'deaths'."
79
+ )
80
+
81
+ def posterior_predictions(sim, idata, seed=None):
82
+ """Make posterior predictions for survival data"""
83
+ if seed is None:
84
+ seed = sim.config.simulation.seed
85
+
86
+ n = idata.posterior.dims["draw"]
87
+ c = idata.posterior.dims["chain"]
88
+
89
+ key = jax.random.PRNGKey(seed)
90
+
91
+ obs, masks = sim.inferer.observation_parser()
92
+
93
+ model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
94
+ if sim.config.inference_numpyro.user_defined_error_model is None:
95
+ model_kwargs["obs"]["survival"] = None
96
+
97
+ # prepare model
98
+ model = partial(
99
+ sim.inferer.inference_model,
100
+ solver=sim.inferer.evaluator,
101
+ **model_kwargs
102
+ )
103
+
104
+ posterior_samples = {
105
+ k: np.array(v["data"]) for k, v
106
+ in idata.unconstrained_posterior.to_dict()["data_vars"].items()
107
+ }
108
+
109
+ predictive = Predictive(
110
+ model, posterior_samples=posterior_samples,
111
+ num_samples=n, batch_ndims=2
112
+ )
113
+
114
+ samples = predictive(key)
115
+
116
+ if sim.config.inference_numpyro.user_defined_error_model is not None:
117
+ chains = []
118
+ for i in range(c):
119
+ # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!!!!!!!
120
+ predictions = list(map(
121
+ partial(
122
+ survival_predictions,
123
+ n_trials=obs["survival"][:, 0].astype(int),
124
+ eps=obs["eps"],
125
+ seed=seed,
126
+ mode="survival",
127
+ ),
128
+ samples["survival"][i]
129
+ ))
130
+ chains.append(predictions)
131
+
132
+ posterior_predictive = {"survival_obs": np.array(chains)}
133
+ else:
134
+ posterior_predictive = {"survival_obs": samples.pop("survival_obs")}
135
+
136
+
137
+ new_idata = sim.inferer.to_arviz_idata(
138
+ posterior=samples,
139
+ posterior_predictive=posterior_predictive,
140
+ n_draws=n,
141
+ n_chains=c
142
+ )
143
+
144
+ # update chain names in case they were subselected (clustering)
145
+ # new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})
146
+ new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})
147
+
148
+ # assert the new posterior matches the old posterior
149
+ tol = sim.config.jaxsolver.atol * 100
150
+ abs_diff_posterior = np.abs(idata.posterior - new_idata.posterior)
151
+ np.testing.assert_array_less(abs_diff_posterior.mean().to_array(), tol)
152
+
153
+ fit_tol = tol * sim.coordinates["time"].max()
154
+ abs_diff_fits = np.abs(new_idata.posterior_model_fits - idata.posterior_model_fits)
155
+ np.testing.assert_array_less(abs_diff_fits.mean().to_array(), fit_tol)
156
+
157
+ idata.posterior_model_fits = new_idata.posterior_model_fits
158
+ idata.posterior_predictive = new_idata.posterior_predictive
159
+
160
+ return idata
@@ -1,6 +1,7 @@
1
1
  from . import base
2
2
  from . import ecx
3
3
  from . import report
4
+ from . import utils
4
5
 
5
6
  from .base import (
6
7
  GutsBase,
@@ -11,4 +12,12 @@ from .base import (
11
12
  from .ecx import ECxEstimator, LPxEstimator
12
13
  from .report import GutsReport
13
14
 
14
- from .mempy import PymobSimulator
15
+ from .mempy import PymobSimulator
16
+ from .utils import (
17
+ GutsBaseError
18
+ )
19
+
20
+ from .constructors import (
21
+ construct_sim_from_config,
22
+ load_idata
23
+ )