guts-base 2.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,233 @@
+ from scipy.stats._multivariate import multinomial_frozen, multi_rv_generic
+ from scipy.stats._discrete_distns import binom
+ import numpy as np
+
+ # TODO: Update example
+ # TODO: Use forward filling again to deal with nans (these were automatically solved by using rv_discrete)
+ # TODO: When all is complete delete conditional_bionm.py
+
+ def conditional_prob_neglogexp(p, p_init=1.0, eps=1e-12):
+     p_ = np.column_stack([p_init, p])
+     # p needs to be clipped so that zero division does not occur if the last p values are zero
+     p_clipped = np.clip(p_, eps, 1.0)
+     # convert to log scale
+     neg_log_p = -np.log(p_clipped)
+     # exponent subtraction is numerically more stable than division
+     return np.exp(neg_log_p[:, :-1] - neg_log_p[:, 1:])
+
+
+ def conditional_prob(p, p_init=1.0):
+     p_ = np.concatenate([[p_init], p[:]])
+     # divide each probability by the previous one
+     return p_[1:] / p_[:-1]
+
+
+ def conditional_prob_from_neglogp(p, p_init=1.0):
+     p_ = np.concatenate([[p_init], p[:]])
+     return np.exp(p_[:-1] - p_[1:])
+
+
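+ # Illustrative sanity check (an editor's sketch, not part of the original module):
+ # on the probability scale the neg-log-exp form is just the ratio of consecutive
+ # survival probabilities, so it should agree with conditional_prob above.
+ # >>> p = np.array([[0.8, 0.4, 0.2]])
+ # >>> np.allclose(conditional_prob_neglogexp(p), conditional_prob(p[0]))
+ # True
+
+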
+ class conditional_binomial(multi_rv_generic):
+     """
+     A scipy distribution for a conditional survival probability distribution.
+
+     Parameters
+     ----------
+     k: Array
+         Number of repeated positive observations of a quantity that
+         can only decrease. k must be monotonically decreasing (e.g. survivors)
+
+     p: Array
+         Survival function of the repeated observation.
+         p must be monotonically decreasing
+
+     n: int
+         The starting number of positive observations (e.g. initial number of organisms
+         in a survival trial)
+
+     Example
+     -------
+     Define a survival function (using a beta cdf) and use it to make multinomial draws
+     to simulate survival over repeated observations
+     >>> import numpy as np
+     >>> from scipy import stats
+     >>> n = 100
+     >>> B = stats.beta(5, 5)
+     >>> p = 1 - B.cdf(np.linspace(0, 1))
+     >>> s = stats.multinomial(n, p=np.diff(p)*-1).rvs()[0]
+     >>> s = n - s.cumsum()
+
+     Construct a frozen distribution
+     >>> from guts_base.prob import conditional_survival
+     >>> S = conditional_survival(n=n, p=p[1:])
+
+     Compute the pmf
+     >>> S.pmf(s)
+
+     Compute the logpmf
+     >>> S.logpmf(s)
+
+     Draw random samples
+     >>> samples = S.rvs(size=(1000, 49))
+
+     Plot the observational variation of a given survival function under repeated
+     observations
+     >>> import matplotlib.pyplot as plt
+     >>> plt.plot(samples.T, color="black", alpha=.02)
+     """
+     p_init = 1.0
+     eps = 1e-12
+     shapes = ("n", "p")
+
+     def __call__(self, n, p, seed=None):
+         """Create a frozen conditional binomial distribution.
+
+         See `multinomial_frozen` for more information.
+         """
+         return conditional_binomial_frozen(n, p, seed)
+
+     def _preprocess_params(self, x, n, p):
+         x = np.array(x, ndmin=2)
+         p = np.array(p, ndmin=2)
+         p_init = np.broadcast_to(self.p_init, p[:, [0]].shape)
+         n_init = np.broadcast_to(n, x[:, [0]].shape)
+         # nan filling is not necessary, because nans are thrown out;
+         # the column stack below shifts the counts so they line up with the p vector
+         n_ = np.column_stack([n_init[:, [0]], x[:, :-1]])
+         p_conditional = conditional_prob_neglogexp(p, p_init=p_init, eps=self.eps)
+         return x, n_, p_conditional
+
+     def logpmf(self, x, n, p):
+         x, n, p = self._preprocess_params(x, n, p)
+         return binom._logpmf(x, n, p)
+
+     def pmf(self, x, n, p):
+         x, n, p = self._preprocess_params(x, n, p)
+         return binom._pmf(x, n, p)
+
+     def rvs(self, n, p, size=None, random_state=None):
+         r"""
+         Generate random samples from the conditional binomial distribution.
+
+         The random generation process is described by the following equations.
+
+         The lethality matrix is initialized with
+         \[
+         L \leftarrow \mathbf{0} \in \mathbb{N}_0^{K \times T}
+         \]
+
+         where $K$ is the number of treatments and $T$ is the number of observations.
+         For each observation \( t = 1, \ldots, T \):
+         \[
+         L_{k,t} \sim \text{Binomial}\left(n = N_{k} - \sum_{t'=0}^{t-1} L_{k,t'},~p = 1 - \frac{S_{k,t}}{S_{k,t-1}}\right)
+         \]
+         where:
+         - \( L_{k,t} \) is the number of organisms dying between $t-1$ and $t$ in each treatment $k$.
+         - \( N_{k} \) is the initial number of living organisms in each treatment \( k \) at the start of the observations ($t=0$).
+         - \( S_{k,t} \) is the survival function (the probability of being alive at a given time) computed for each observation $t$ in $T$.
+
+         Finally:
+         \[
+         S_{k,t}^{obs} = N_{k} - \sum_{t'=0}^{t} L_{k,t'}
+         \]
+
+         Parameters
+         ----------
+         n: int or Array of int
+             The initial number of positive observations (e.g. initial number of organisms
+             in a survival trial).
+         p: Array
+             Survival function of the repeated observation.
+             p must be monotonically decreasing.
+         size: tuple, optional
+             Shape of the random variates to generate.
+         random_state: RandomState or int, optional
+             If seed is not None, it will be used by the RandomState to generate
+             the random variates.
+
+         Returns
+         -------
+         ndarray
+             Random samples from the conditional binomial distribution representing the
+             number of entities surviving at each time step.
+         """
+         random_state = self._get_random_state(random_state)
+
+         p = np.array(p, ndmin=2)
+         p_init = np.broadcast_to(self.p_init, p[:, [0]].shape)
+         n_init = np.broadcast_to(n, p.shape)
+
+         p_conditional = conditional_prob_neglogexp(p, p_init=p_init, eps=self.eps)
+
+         if size is None:
+             size = p.shape
+
+         # axis-0 is the batch dimension
+         # axis-1 is the time dimension (probability)
+         L = np.zeros(shape=size)
+
+         for i in range(L.shape[1]):
+             # calculate the binomial response of the conditional survival
+             # i.e. the probability to die within an interval conditional on
+             # having survived until the beginning of that interval
+             L[..., i] = random_state.binomial(
+                 p=1 - p_conditional[:, i],
+                 n=n_init[:, i] - L.sum(axis=-1).astype(int),
+                 size=size[:-1]
+             )
+
+         return n_init - L.cumsum(axis=-1)
+
+ conditional_survival = conditional_binomial()
+
+
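+ # Hedged, illustrative doctest-style check (an editor's sketch, not part of the original
+ # module): because the pmf is evaluated per observation interval, the product of the
+ # per-interval values should match manually chained scipy binomials.
+ # >>> from scipy import stats
+ # >>> S = conditional_survival(n=10, p=[0.8, 0.4, 0.2])
+ # >>> manual = (stats.binom(10, 0.8).pmf(8)
+ # ...           * stats.binom(8, 0.5).pmf(4)
+ # ...           * stats.binom(4, 0.5).pmf(1))
+ # >>> bool(np.isclose(S.pmf([8, 4, 1]).prod(), manual))
+ # True
+
+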
+ class conditional_binomial_frozen(multinomial_frozen):
+
+     def __init__(self, n, p, seed=None):
+         self._dist = conditional_binomial(seed)
+         self.n, self.p = n, p
+         # self.n, self.p, self.npcond = self._dist._process_parameters(n, p)
+
+         # # monkey patch self._dist
+         # def _process_parameters(n, p):
+         #     return self.n, self.p, self.npcond
+
+         # self._dist._process_parameters = _process_parameters
+
+     def logpmf(self, x):
+         return self._dist.logpmf(x, n=self.n, p=self.p)
+
+     def pmf(self, x):
+         return self._dist.pmf(x, n=self.n, p=self.p)
+
+     def rvs(self, n=None, size=None, random_state=None):
+         if n is None:
+             n = self.n
+         # pass the (possibly overridden) n through instead of always using self.n
+         return self._dist.rvs(n=n, p=self.p, size=size, random_state=random_state)
+
+
+ if __name__ == "__main__":
+
+     S = conditional_survival(n=10, p=[[0.8, 0.4, 0.2], [0.8, 0.4, 0.2]])
+
+     prob = S.pmf([[10, 5, 2], [7, 2, 0]])
+     prob = S.pmf([10, 5, 2])
+     sample = S.rvs()
+     sample = S.rvs(size=(10, 2, 3))
+
+     S = conditional_survival(n=10, p=[0.8, 0.4, 0.2])
+     prob = S.pmf([[10, 5, 2], [7, 2, 0]])
+     prob = S.pmf([10, 5, 2])
+     S.rvs(size=(10, 3))
+     S.rvs()
+
+     sample
+
+     S = conditional_survival(n=[[10], [2000]], p=[[0.8, 0.4, 0.2], [0.8, 0.4, 0.2]])
+     prob = S.pmf([[10, 5, 2], [7, 2, 0]])
+     S.rvs()
+
@@ -0,0 +1,164 @@
+ from functools import partial
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ import numpyro
+ from numpyro.infer import Predictive
+ from typing import Literal
+
+
+ def survival_predictions(
+     probs, n_trials,
+     eps: float = 0.0,
+     seed=1,
+     mode: Literal["survival", "lethality", "deaths"] = "survival"
+ ):
+     """Generate predictions for survival based on a multinomial survival distribution
+
+     Parameters
+     ----------
+     probs : ArrayLike
+         2D Array of survival probabilities for each observation time per experiment
+         (converted internally into multinomial death probabilities per interval)
+         dims=(experiment, time)
+     n_trials : ArrayLike
+         1D Array denoting the number of organisms at the beginning of each experiment
+         dims=(experiment,)
+     eps : float, optional
+         Tolerance used to truncate and clip the probabilities against numerical noise,
+         by default 0.0
+     seed : int, optional
+         Seed for the random number generator, by default 1
+     mode : str, optional
+         How should the random draws be returned?
+         - survival: Decreasing from n_trials to 0
+         - lethality: Increasing from 0 to n_trials
+         - deaths: Between 0 and n_trials in each interval, summing to n_trials
+     """
+
+     def survival_to_death_probs(pr_survival):
+         # truncate here, because numeric errors below the solver tolerance can
+         # lead to negative values in the difference; this needs to be corrected here
+         pr_survival_ = np.trunc(pr_survival / eps) * eps
+         pr_death = pr_survival_[:-1] - pr_survival_[1:]
+
+         pr_death = np.concatenate([
+             # concatenate a zero at the beginning in order to "simulate" no
+             # deaths at T = 0
+             jnp.zeros((1,)),
+             # Delta S
+             pr_death,
+             # the remaining mortality as T -> infinity
+             jnp.ones((1,)) - pr_death.sum()
+         ])
+
+         # make sure the vector is neither zero nor one (this is always problematic for
+         # probabilities) and make sure the vector sums to 1
+         pr_death = np.clip(pr_death, eps, 1 - eps)
+         pr_death = pr_death / pr_death.sum()
+         return pr_death
+
+     rng = np.random.default_rng(seed)
+     deaths = jnp.array(list(map(
+         lambda n, pr_survival: rng.multinomial(
+             n=n, pvals=survival_to_death_probs(pr_survival)
+         ),
+         n_trials,
+         probs
+     )))
+
+     # remove the last observation to trim off the simulated unobserved mortality
+     deaths = deaths[:, :-1]
+
+     if mode == "deaths":
+         return deaths
+     elif mode == "lethality":
+         return deaths.cumsum(axis=1)
+     elif mode == "survival":
+         return np.expand_dims(n_trials, axis=1) - deaths.cumsum(axis=1)
+     else:
+         raise NotImplementedError(
+             f"Mode {mode} is not implemented. "
+             "Use one of 'survival', 'lethality', or 'deaths'."
+         )
+
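+ # Illustrative usage sketch (hypothetical inputs, not taken from the package): two
+ # experiments observed at three times, ten organisms each; in "survival" mode the
+ # result has shape (experiment, time) and decreases from n_trials.
+ # >>> probs = np.array([[1.0, 0.6, 0.3], [1.0, 0.8, 0.5]])
+ # >>> n_trials = np.array([10, 10])
+ # >>> survival_predictions(probs, n_trials, eps=1e-8, mode="survival").shape
+ # (2, 3)
+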
+ def posterior_predictions(sim, idata, seed=None):
+     """Make posterior predictions for survival data"""
+     if seed is None:
+         seed = sim.config.simulation.seed
+
+     n = idata.posterior.dims["draw"]
+     c = idata.posterior.dims["chain"]
+
+     key = jax.random.PRNGKey(seed)
+
+     obs, masks = sim.inferer.observation_parser()
+
+     model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
+     if sim.config.inference_numpyro.user_defined_error_model is None:
+         model_kwargs["obs"]["survival"] = None
+
+     # prepare model
+     model = partial(
+         sim.inferer.inference_model,
+         solver=sim.inferer.evaluator,
+         **model_kwargs
+     )
+
+     posterior_samples = {
+         k: np.array(v["data"]) for k, v
+         in idata.unconstrained_posterior.to_dict()["data_vars"].items()
+     }
+
+     predictive = Predictive(
+         model, posterior_samples=posterior_samples,
+         num_samples=n, batch_ndims=2
+     )
+
+     samples = predictive(key)
+
+     if sim.config.inference_numpyro.user_defined_error_model is not None:
+         chains = []
+         for i in range(c):
+             # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!
+             predictions = list(map(
+                 partial(
+                     survival_predictions,
+                     n_trials=obs["survival"][:, 0].astype(int),
+                     eps=obs["eps"],
+                     seed=seed,
+                     mode="survival",
+                 ),
+                 samples["survival"][i]
+             ))
+             chains.append(predictions)
+
+         posterior_predictive = {"survival_obs": np.array(chains)}
+     else:
+         posterior_predictive = {"survival_obs": samples.pop("survival_obs")}
+
+     new_idata = sim.inferer.to_arviz_idata(
+         posterior=samples,
+         posterior_predictive=posterior_predictive,
+         n_draws=n,
+         n_chains=c
+     )
+
+     # update chain names in case they were subselected (clustering)
+     new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})
+
+     # assert that the new posterior matches the old posterior
+     tol = sim.config.jaxsolver.atol * 100
+     abs_diff_posterior = np.abs(idata.posterior - new_idata.posterior)
+     np.testing.assert_array_less(abs_diff_posterior.mean().to_array(), tol)
+
+     fit_tol = tol * sim.coordinates["time"].max()
+     abs_diff_fits = np.abs(new_idata.posterior_model_fits - idata.posterior_model_fits)
+     np.testing.assert_array_less(abs_diff_fits.mean().to_array(), fit_tol)
+
+     idata.posterior_model_fits = new_idata.posterior_model_fits
+     idata.posterior_predictive = new_idata.posterior_predictive
+
+     return idata
+
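+ # Hedged usage sketch (hypothetical setup, not from the package docs): `sim` is a
+ # guts_base simulation object with a configured numpyro inferer and `idata` an arviz
+ # InferenceData from a previous fit; the call re-runs the model, checks the posterior
+ # against the original within the solver tolerance, and attaches the predictions.
+ # >>> idata = posterior_predictions(sim, idata, seed=1)
+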
@@ -0,0 +1,28 @@
+ from . import base
+ from . import ecx
+ from . import report
+ from . import utils
+ from . import transformer
+ from . import config
+ from . import units
+
+ from .base import (
+     GutsBase,
+     GutsSimulationConstantExposure,
+     GutsSimulationVariableExposure
+ )
+
+ from .ecx import ECxEstimator, LPxEstimator
+ from .report import GutsReport, ParameterConverter
+
+ from .mempy import PymobSimulator
+ from .utils import (
+     GutsBaseError
+ )
+
+ from .constructors import (
+     construct_sim_from_config,
+     load_idata
+ )
+
+ from .config import GutsBaseConfig