guts-base 0.8.5__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

Files changed (39)
  1. {guts_base-0.8.5 → guts_base-1.0.0}/PKG-INFO +3 -4
  2. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/__init__.py +2 -1
  3. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/__init__.py +1 -1
  4. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/generator.py +6 -5
  5. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/survival.py +6 -0
  6. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/mod.py +27 -80
  7. guts_base-1.0.0/guts_base/prob.py +160 -0
  8. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/sim/__init__.py +10 -1
  9. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/sim/base.py +350 -78
  10. guts_base-1.0.0/guts_base/sim/constructors.py +31 -0
  11. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/sim/ecx.py +221 -63
  12. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/sim/mempy.py +85 -70
  13. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/sim/report.py +9 -9
  14. guts_base-1.0.0/guts_base/sim/utils.py +10 -0
  15. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base.egg-info/PKG-INFO +3 -4
  16. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base.egg-info/SOURCES.txt +3 -1
  17. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base.egg-info/requires.txt +1 -2
  18. {guts_base-0.8.5 → guts_base-1.0.0}/pyproject.toml +4 -5
  19. guts_base-1.0.0/tests/test_ecx.py +38 -0
  20. {guts_base-0.8.5 → guts_base-1.0.0}/tests/test_from_pymob.py +1 -8
  21. guts_base-1.0.0/tests/test_simulations.py +61 -0
  22. guts_base-0.8.5/guts_base/prob.py +0 -412
  23. guts_base-0.8.5/guts_base/sim.py +0 -0
  24. guts_base-0.8.5/tests/test_simulations.py +0 -109
  25. {guts_base-0.8.5 → guts_base-1.0.0}/LICENSE +0 -0
  26. {guts_base-0.8.5 → guts_base-1.0.0}/README.md +0 -0
  27. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/expydb.py +0 -0
  28. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/openguts.py +0 -0
  29. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/preprocessing.py +0 -0
  30. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/time_of_death.py +0 -0
  31. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/data/utils.py +0 -0
  32. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base/plot.py +0 -0
  33. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base.egg-info/dependency_links.txt +0 -0
  34. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base.egg-info/entry_points.txt +0 -0
  35. {guts_base-0.8.5 → guts_base-1.0.0}/guts_base.egg-info/top_level.txt +0 -0
  36. {guts_base-0.8.5 → guts_base-1.0.0}/setup.cfg +0 -0
  37. {guts_base-0.8.5 → guts_base-1.0.0}/tests/test_data_import.py +0 -0
  38. {guts_base-0.8.5 → guts_base-1.0.0}/tests/test_scripted_simulations.py +0 -0
  39. {guts_base-0.8.5 → guts_base-1.0.0}/tests/test_symbolic_solve.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guts_base
3
- Version: 0.8.5
3
+ Version: 1.0.0
4
4
  Summary: Basic GUTS model implementation in pymob
5
5
  Author-email: Florian Schunck <fluncki@protonmail.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -686,14 +686,13 @@ Classifier: Natural Language :: English
686
686
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
687
687
  Classifier: Operating System :: OS Independent
688
688
  Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
689
- Requires-Python: >=3.10
689
+ Requires-Python: <3.12,>=3.10
690
690
  Description-Content-Type: text/markdown
691
691
  License-File: LICENSE
692
692
  Requires-Dist: openpyxl>=3.1.3
693
693
  Requires-Dist: Bottleneck>=1.5.0
694
694
  Requires-Dist: expydb>=0.6.0
695
- Requires-Dist: mempyguts>=1.5.0
696
- Requires-Dist: pymob[numpyro]<1.0.0,>=0.4.1
695
+ Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.5.10
697
696
  Provides-Extra: dev
698
697
  Requires-Dist: pytest>=7.3; extra == "dev"
699
698
  Requires-Dist: bumpver; extra == "dev"
@@ -4,11 +4,12 @@ from . import data
4
4
  from . import prob
5
5
  from . import plot
6
6
 
7
- __version__ = "0.8.5"
7
+ __version__ = "1.0.0"
8
8
 
9
9
  from .sim import (
10
10
  GutsBase,
11
11
  PymobSimulator,
12
12
  ECxEstimator,
13
13
  LPxEstimator,
14
+ GutsBaseError,
14
15
  )
@@ -21,7 +21,7 @@ from .survival import (
21
21
  generate_survival_repeated_observations
22
22
  )
23
23
 
24
- from .generator import create_artificial_data, design_exposure_scenario
24
+ from .generator import create_artificial_data, design_exposure_scenario, ExposureDataDict
25
25
 
26
26
  from .expydb import (
27
27
  to_dataset,
@@ -11,7 +11,8 @@ def create_artificial_data(
11
11
  t_max,
12
12
  dt,
13
13
  exposure_paths=["oral", "topical", "contact"],
14
- intensity=[0.1, 0.5, 0.05]
14
+ intensity=[0.1, 0.5, 0.05],
15
+ seed=1,
15
16
  ):
16
17
  rng = np.random.default_rng(1)
17
18
  time = np.arange(0, t_max, step=dt) # daily time resolution
@@ -67,7 +68,8 @@ def design_exposure_scenario(
67
68
  """
68
69
  TODO: tmax, dt and eps are probably not necessary
69
70
  """
70
- time = np.arange(0, t_max, step=dt) # daily time resolution
71
+ # add dt so that t_max is definitely included
72
+ time = np.arange(0, t_max+dt, step=dt) # daily time resolution
71
73
  time = np.unique(np.concatenate([time] + [
72
74
  np.array([time[-1] if vals["end"] is None else vals["end"]])
73
75
  for key, vals in exposures.items()
@@ -79,13 +81,12 @@ def design_exposure_scenario(
79
81
  treat = design_exposure_timeseries(time, expo, eps)
80
82
  treatments.update({key: treat})
81
83
 
82
- data = np.column_stack(list(treatments.values())).squeeze()
84
+ data = np.column_stack(list(treatments.values()))
83
85
  data = np.expand_dims(data, axis=0)
84
86
 
85
87
  coords = {"id": [0], "time": time}
86
88
 
87
- if len(exposures) > 1:
88
- coords.update({exposure_dimension: list(treatments.keys())})
89
+ coords.update({exposure_dimension: list(treatments.keys())})
89
90
 
90
91
  exposures_dataset = xr.Dataset(
91
92
  data_vars={"exposure": (tuple(coords.keys()), data)},
@@ -40,6 +40,12 @@ def prepare_survival_data_for_conditional_binomial(observations: xr.Dataset) ->
40
40
  nsurv.isel(time=list(range(0, len(nsurv.time)-1))).values
41
41
  ]).astype(int))})
42
42
 
43
+ observations = observations.assign_coords({
44
+ "survivors_at_start": (("id", "time"), np.broadcast_to(
45
+ nsurv.isel(time=0).values.reshape(-1,1),
46
+ shape=nsurv.shape
47
+ ).astype(int))})
48
+
43
49
  return observations
44
50
 
45
51
 
@@ -1,10 +1,3 @@
1
- """
2
- GUTS models
3
-
4
- TODO: Import guts models from mempy and update to work well with jax
5
- TODO: Based on this implement the bufferguts model in mempy and import in the
6
- bufferguts case-study
7
- """
8
1
  from functools import partial
9
2
  from pathlib import Path
10
3
 
@@ -19,73 +12,37 @@ from pymob.solvers.symbolic import (
19
12
  PiecewiseSymbolicODESolver, FunctionPythonCode, get_return_arguments, dX_dt2X
20
13
  )
21
14
 
22
- from mempy.model import (
23
- RED_IT,
24
- RED_SD,
25
- RED_IT_DA,
26
- RED_SD_DA,
27
- RED_IT_IA,
28
- RED_SD_IA,
29
- BufferGUTS_IT,
30
- BufferGUTS_IT_CA,
31
- BufferGUTS_IT_DA
32
- )
15
+ class RED_SD:
16
+ """Simplest guts model, mainly for testing"""
17
+ @staticmethod
18
+ def _rhs_jax(t, y, x_in, kd, b, m, hb):
19
+ D, H = y
20
+ dD_dt = kd * (x_in.evaluate(t) - D)
21
+ dH_dt = b * jnp.maximum(0.0, D - m) + hb
22
+ return dD_dt, dH_dt
23
+
24
+ @staticmethod
25
+ def _solver_post_processing(results, t, interpolation):
26
+ results["survival"] = jnp.exp(-results["H"])
27
+ results["exposure"] = jax.vmap(interpolation.evaluate)(t)
28
+ return results
33
29
 
34
30
  red_sd = RED_SD._rhs_jax
35
- red_it = RED_IT._rhs_jax
31
+ red_sd_post_processing = RED_SD._solver_post_processing
36
32
 
33
+ class RED_SD_DA(RED_SD):
34
+ @staticmethod
35
+ def _rhs_jax(t, y, x_in, kd, w, b, m, hb):
36
+ D, H = y
37
+ dD_dt = kd * (x_in.evaluate(t) - D)
38
+ dH_dt = b * jnp.maximum(0.0, jnp.sum(w * D) - m) + hb
39
+ return dD_dt, dH_dt
37
40
 
38
- def p_survival(results, t, interpolation, z, k_k, h_b):
39
- """Computes the stochastic death survival probability after computing
40
-
41
- """
42
- # calculate survival
43
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
44
- p_surv = survival_jax(t, results["D"], z, k_k, h_b)
45
- results["survival"] = p_surv
46
- results["lethality"] = 1 - p_surv
47
- return results
48
-
49
- def it_post_processing(results, t, interpolation, alpha, beta, h_b, eps):
50
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
51
- p_surv = survival_IT_jax(t, results["D"], alpha, beta, h_b, eps)
52
- results["survival"] = p_surv
53
- results["H"] = - jnp.log(p_surv)
54
- return results
55
-
56
- def post_exposure(results, t, interpolation):
57
- results["survival"] = jnp.exp(-results["H"])
58
- results["exposure"] = jax.vmap(interpolation.evaluate)(t)
59
- return results
60
-
61
-
62
- def no_post_processing(results):
63
- return results
64
-
65
-
66
- @jax.jit
67
- def survival_jax(t, damage, z, kk, h_b):
68
- """
69
- survival probability derived from hazard
70
- first calculate cumulative Hazard by integrating hazard cumulatively over t
71
- then calculate the resulting survival probability
72
- It was checked that `survival_jax` behaves exactly the same as `survival`
73
- """
74
- hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
75
- H = jnp.array([jax.scipy.integrate.trapezoid(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
76
- # H = jnp.array([jnp.trapz(hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))])
77
- S = jnp.exp(-H)
78
-
79
- return S
80
-
81
- @jax.jit
82
- def survival_IT_jax(t, damage, alpha, beta, h_b, eps):
83
- d_max = jnp.squeeze(jnp.array([jnp.max(damage[:i+1])+eps for i in range(len(t))]))
84
- F = jnp.where(d_max > 0, 1.0 / (1.0 + (d_max / alpha) ** -beta), 0)
85
- S = 1.0 * (jnp.array([1.0], dtype=float) - F) * jnp.exp(-h_b * t)
86
- return S
87
-
88
- def guts_jax(t, y, C_0, k_d, z, b, h_b):
41
+ red_sd_da = RED_SD_DA._rhs_jax
42
+ red_sd_da_post_processing = RED_SD_DA._solver_post_processing
43
+
44
+
45
+ def guts_constant_exposure(t, y, C_0, k_d, z, b, h_b):
89
46
  # for constant exposure
90
47
  D, H, S = y
91
48
  dD_dt = k_d * (C_0 - D)
@@ -97,16 +54,6 @@ def guts_jax(t, y, C_0, k_d, z, b, h_b):
97
54
 
98
55
  return dD_dt, dH_dt, dS_dt
99
56
 
100
-
101
- def RED_IT(t, y, x_in, kd):
102
- D, = y
103
- C = x_in.evaluate(t)
104
-
105
- dD_dt = kd * (C - D)
106
-
107
- return (dD_dt, )
108
-
109
-
110
57
  def guts_variable_exposure(t, y, x_in, k_d, z, b, h_b):
111
58
  # for variable exposure
112
59
  D, H, S = y
@@ -0,0 +1,160 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from numpyro.infer import Predictive
6
+ from typing import Literal
7
+
8
+
9
+ def survival_predictions(
10
+ probs, n_trials,
11
+ eps: float = 0.0,
12
+ seed=1,
13
+ mode: Literal["survival", "lethality", "deaths"] = "survival"
14
+ ):
15
+ """Generate predictions for survival based on a multinomial survival distribution
16
+
17
+ Parameters
18
+ ----------
19
+ probs : ArrayLike
20
+ 2D Array denoting the survival probabilities at each time point (converted internally to multinomial probabilities of deaths for each time interval)
21
+ per experiment
22
+ dims=(experiment, time)
23
+ n_trials : ArrayLike
24
+ 1D Array denoting the number of organisms at the beginning in each experiment
25
+ dims = (experiment,)
26
+ seed : int, optional
27
+ Seed for the random number generator, by default 1
28
+ mode : str, optional
29
+ How should the random draws be returned?
30
+ - survival: Decreasing from n_trials to 0
31
+ - lethality: Increasing from 0 to n_trials
32
+ - deaths: Between 0 and n_trials in each interval. Summing to n_trials
33
+ """
34
+
35
+ def survival_to_death_probs(pr_survival):
36
+ # truncate here, because numeric errors below the solver tolerance can
37
+ # lead to negative values in the difference. This needs to be cured here
38
+ pr_survival_ = np.trunc(pr_survival / eps) * eps
39
+ pr_death = pr_survival_[:-1] - pr_survival_[1:]
40
+
41
+ pr_death = np.concatenate([
42
+ # concatenate a zero at the beginning in order to "simulate" no
43
+ # deaths at T = 0
44
+ jnp.zeros((1,)),
45
+ # Delta S
46
+ pr_death,
47
+ # The remaining mortality as T -> infinity
48
+ jnp.ones((1,))-pr_death.sum()
49
+ ])
50
+
51
+ # make sure the vector is not zero or 1 (this is always problematic for
52
+ # probabilities) and make sure the vector sums to 1
53
+ pr_death = np.clip(pr_death, eps, 1-eps)
54
+ pr_death = pr_death / pr_death.sum()
55
+ return pr_death
56
+
57
+ rng = np.random.default_rng(seed)
58
+ deaths = jnp.array(list(map(
59
+ lambda n, pr_survival: rng.multinomial(
60
+ n=n, pvals=survival_to_death_probs(pr_survival)
61
+ ),
62
+ n_trials,
63
+ probs
64
+ )))
65
+
66
+ # remove the last observations to trim off the simulated unobserved mortality
67
+ deaths = deaths[:, :-1]
68
+
69
+ if mode == "deaths":
70
+ return deaths
71
+ elif mode == "lethality":
72
+ return deaths.cumsum(axis=1)
73
+ elif mode == "survival":
74
+ return np.expand_dims(n_trials, axis=1) - deaths.cumsum(axis=1)
75
+ else:
76
+ raise NotImplementedError(
77
+ f"Mode {mode} is not implemented."+
78
+ "Use one of 'survival', 'lethality', or 'deaths'."
79
+ )
80
+
81
+ def posterior_predictions(sim, idata, seed=None):
82
+ """Make posterior predictions for survival data"""
83
+ if seed is None:
84
+ seed = sim.config.simulation.seed
85
+
86
+ n = idata.posterior.dims["draw"]
87
+ c = idata.posterior.dims["chain"]
88
+
89
+ key = jax.random.PRNGKey(seed)
90
+
91
+ obs, masks = sim.inferer.observation_parser()
92
+
93
+ model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
94
+ if sim.config.inference_numpyro.user_defined_error_model is None:
95
+ model_kwargs["obs"]["survival"] = None
96
+
97
+ # prepare model
98
+ model = partial(
99
+ sim.inferer.inference_model,
100
+ solver=sim.inferer.evaluator,
101
+ **model_kwargs
102
+ )
103
+
104
+ posterior_samples = {
105
+ k: np.array(v["data"]) for k, v
106
+ in idata.unconstrained_posterior.to_dict()["data_vars"].items()
107
+ }
108
+
109
+ predictive = Predictive(
110
+ model, posterior_samples=posterior_samples,
111
+ num_samples=n, batch_ndims=2
112
+ )
113
+
114
+ samples = predictive(key)
115
+
116
+ if sim.config.inference_numpyro.user_defined_error_model is not None:
117
+ chains = []
118
+ for i in range(c):
119
+ # TODO: Switch to vmap and jit. A previous attempt did not work, so test the result thoroughly if you do.
120
+ predictions = list(map(
121
+ partial(
122
+ survival_predictions,
123
+ n_trials=obs["survival"][:, 0].astype(int),
124
+ eps=obs["eps"],
125
+ seed=seed,
126
+ mode="survival",
127
+ ),
128
+ samples["survival"][i]
129
+ ))
130
+ chains.append(predictions)
131
+
132
+ posterior_predictive = {"survival_obs": np.array(chains)}
133
+ else:
134
+ posterior_predictive = {"survival_obs": samples.pop("survival_obs")}
135
+
136
+
137
+ new_idata = sim.inferer.to_arviz_idata(
138
+ posterior=samples,
139
+ posterior_predictive=posterior_predictive,
140
+ n_draws=n,
141
+ n_chains=c
142
+ )
143
+
144
+ # update chain names in case they were subselected (clustering)
145
+ # new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})
146
+ new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})
147
+
148
+ # assert the new posterior matches the old posterior
149
+ tol = sim.config.jaxsolver.atol * 100
150
+ abs_diff_posterior = np.abs(idata.posterior - new_idata.posterior)
151
+ np.testing.assert_array_less(abs_diff_posterior.mean().to_array(), tol)
152
+
153
+ fit_tol = tol * sim.coordinates["time"].max()
154
+ abs_diff_fits = np.abs(new_idata.posterior_model_fits - idata.posterior_model_fits)
155
+ np.testing.assert_array_less(abs_diff_fits.mean().to_array(), fit_tol)
156
+
157
+ idata.posterior_model_fits = new_idata.posterior_model_fits
158
+ idata.posterior_predictive = new_idata.posterior_predictive
159
+
160
+ return idata
@@ -1,6 +1,7 @@
1
1
  from . import base
2
2
  from . import ecx
3
3
  from . import report
4
+ from . import utils
4
5
 
5
6
  from .base import (
6
7
  GutsBase,
@@ -11,4 +12,12 @@ from .base import (
11
12
  from .ecx import ECxEstimator, LPxEstimator
12
13
  from .report import GutsReport
13
14
 
14
- from .mempy import PymobSimulator
15
+ from .mempy import PymobSimulator
16
+ from .utils import (
17
+ GutsBaseError
18
+ )
19
+
20
+ from .constructors import (
21
+ construct_sim_from_config,
22
+ load_idata
23
+ )