guts-base 2.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guts_base/__init__.py +15 -0
- guts_base/data/__init__.py +35 -0
- guts_base/data/expydb.py +248 -0
- guts_base/data/generator.py +191 -0
- guts_base/data/openguts.py +296 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +148 -0
- guts_base/data/time_of_death.py +595 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +332 -0
- guts_base/plot.py +201 -0
- guts_base/prob/__init__.py +13 -0
- guts_base/prob/binom.py +18 -0
- guts_base/prob/conditional_binom.py +118 -0
- guts_base/prob/conditional_binom_mv.py +233 -0
- guts_base/prob/predictions.py +164 -0
- guts_base/sim/__init__.py +28 -0
- guts_base/sim/base.py +1286 -0
- guts_base/sim/config.py +170 -0
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +585 -0
- guts_base/sim/mempy.py +290 -0
- guts_base/sim/report.py +405 -0
- guts_base/sim/transformer.py +548 -0
- guts_base/sim/units.py +313 -0
- guts_base/sim/utils.py +10 -0
- guts_base-2.0.0b0.dist-info/METADATA +853 -0
- guts_base-2.0.0b0.dist-info/RECORD +32 -0
- guts_base-2.0.0b0.dist-info/WHEEL +5 -0
- guts_base-2.0.0b0.dist-info/entry_points.txt +3 -0
- guts_base-2.0.0b0.dist-info/licenses/LICENSE +674 -0
- guts_base-2.0.0b0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
from scipy.stats._multivariate import multinomial_frozen, multi_rv_generic
|
|
2
|
+
from scipy.stats._discrete_distns import binom
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
# TODO: Update example
|
|
6
|
+
# TODO: Use forward filling again to deal with nans (these were automatically solved by using rv_discrete)
|
|
7
|
+
# TODO: When all is complete, delete conditional_binom.py
|
|
8
|
+
|
|
9
|
+
def conditional_prob_neglogexp(p, p_init=1.0, eps=1e-12):
    """Convert survival probabilities into conditional survival probabilities.

    For every row the entry at position ``t`` is ``p[t] / p[t-1]``, where the
    value preceding the first column is ``p_init``. The ratio is evaluated as
    ``exp(log(p[t]) - log(p[t-1]))``.

    Parameters
    ----------
    p : array_like
        Survival probabilities with time along the last axis. When ``p_init``
        is a scalar, a 1D ``p`` is treated as a single row.
    p_init : float or array_like, optional
        Survival probability preceding the first column of ``p``. A scalar is
        expanded to one start value per row of ``p``; an array is
        column-stacked in front of ``p`` unchanged.
    eps : float, optional
        Lower bound applied to all probabilities before taking the logarithm.

    Returns
    -------
    ndarray
        2D array with the same number of columns as ``p``.
    """
    p_init = np.asarray(p_init)
    if p_init.ndim == 0:
        # A scalar start value can only be column-stacked with a single row.
        # Expand it so that the default works for any number of rows.
        p = np.array(p, ndmin=2)
        p_init = np.broadcast_to(p_init, (p.shape[0], 1))

    p_ = np.column_stack([p_init, p])
    # p needs to be clipped so that log(0) (and thereby a zero division) does
    # not occur if the last p values are zero
    p_clipped = np.clip(p_, eps, 1.0)
    # convert to logscale
    neg_log_p = -np.log(p_clipped)
    # exponent subtraction is numerically more stable than division
    return np.exp(neg_log_p[:, :-1] - neg_log_p[:, 1:])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def conditional_prob(p, p_init=1.0):
    """Return the ratio of each probability to the one preceding it.

    Parameters
    ----------
    p : array_like
        1D sequence of probabilities.
    p_init : float or array_like, optional
        Value placed in front of ``p``; it is the denominator for the first
        entry of ``p``.

    Returns
    -------
    ndarray
        Array with the same length as ``p`` holding ``p[t] / p[t-1]``.
    """
    # np.concatenate rejects zero-dimensional inputs, so a scalar start value
    # has to be promoted to a 1D array first
    p_ = np.concatenate([np.atleast_1d(p_init), np.asarray(p)])
    # divide each value by the previous probability
    return p_[1:] / p_[:-1]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def conditional_prob_from_neglogp(p, p_init=1.0):
    """Exponentiate the differences between consecutive entries.

    ``p_init`` is placed in front of ``p`` and the result at position ``t`` is
    ``exp(previous - current)`` of the padded sequence.

    Parameters
    ----------
    p : array_like
        1D sequence; judging by the function name these are negative
        log-probabilities.
    p_init : float, optional
        Value prepended to ``p`` without any transformation.
        NOTE(review): if ``p`` holds negative log-probabilities, a start
        probability of one would correspond to ``0.0`` here, not ``1.0`` --
        confirm the intended default with the callers.

    Returns
    -------
    ndarray
        Array with the same length as ``p``.
    """
    padded = np.concatenate(([p_init], p[:]))
    previous = padded[:-1]
    current = padded[1:]
    return np.exp(previous - current)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class conditional_binomial(multi_rv_generic):
    """
    A scipy distribution for a conditional survival probability distribution

    The observations are evaluated as a chain of binomial variables: the
    count at each time step is binomial with the previous count (``n`` for
    the first step) as number of trials and the conditional survival
    probability ``p[t] / p[t-1]`` as success probability.

    Parameters
    ----------
    k: Array
        Number of repeated positive observations of a quantity that
        can only decrease. k must be monotonically decreasing (e.g. survivors)

    p: Array
        survival function of the repeated observation.
        p must be monotonically decreasing

    n: int
        The starting number of positive observations (e.g. initial number of organisms
        in a survival trial)

    Example
    -------
    Define a survival function (using a beta cdf and use it to make multinomial draws)
    to simulate survivals from repeated observations
    >>> n = 100
    >>> B = stats.beta(5, 5)
    >>> p = 1 - B.cdf(np.linspace(0, 1))
    >>> s = stats.multinomial(n, p=np.diff(p)*-1).rvs()[0]
    >>> s = n - s.cumsum()

    construct a frozen distribution
    >>> from guts_base.prob import conditional_survival
    >>> S = conditional_survival(n=n, p=p[1:])

    Compute the pmf
    >>> S.pmf(s)

    Compute the logpmf
    >>> S.logpmf(s)

    Draw random samples
    >>> samples = S.rvs(size=(1000, 49))

    Plot the observational variation of a given survival function under repeated
    Observations
    >>> plt.plot(samples.T, color="black", alpha=.02)

    """
    # survival probability placed in front of the first column of ``p``
    p_init = 1.0
    # lower clipping bound handed to ``conditional_prob_neglogexp``
    eps = 1e-12
    shapes = ("n", "p")

    def __call__(self, n, p, seed=None):
        """Create a frozen conditional binomial distribution.

        See `conditional_binomial_frozen` for more information.
        """
        return conditional_binomial_frozen(n, p, seed)

    def _preprocess_params(self, x, n, p):
        """Bring observations and parameters into the binomial form.

        Parameters
        ----------
        x: array_like
            Observed counts; promoted to 2D with time along the last axis.
        n: int or array_like
            Initial counts, broadcast against the first column of ``x``.
        p: array_like
            Survival function; promoted to 2D with time along the last axis.

        Returns
        -------
        tuple
            ``(x, n_, p_conditional)`` where ``n_`` holds the number of
            trials for each time step (``n`` followed by the preceding
            observations) and ``p_conditional`` holds the conditional
            survival probabilities.
        """
        x = np.array(x, ndmin=2)
        p = np.array(p, ndmin=2)
        p_init = np.broadcast_to(self.p_init, p[:, [0]].shape)
        n_init = np.broadcast_to(n, x[:, [0]].shape)
        # nan filling is not necessary, because nans are thrown out this shifts the
        # p vector to where it belongs
        n_ = np.column_stack([n_init[:, [0]], x[:, :-1]])
        p_conditional = conditional_prob_neglogexp(p, p_init=p_init, eps=self.eps)
        return x, n_, p_conditional

    def logpmf(self, x, n, p):
        """Element-wise binomial log-probability of ``x`` given ``n`` and ``p``."""
        x, n, p = self._preprocess_params(x, n, p)
        return binom._logpmf(x, n, p)

    def pmf(self, x, n, p):
        """Element-wise binomial probability of ``x`` given ``n`` and ``p``."""
        x, n, p = self._preprocess_params(x, n, p)
        return binom._pmf(x, n, p)

    def rvs(self, n, p, size=None, random_state=None):
        r"""
        Generate random samples from the conditional binomial distribution.

        The random generation process is described by the following equations:

        The lethality matrix is initialized with
        \[
        L \leftarrow \mathbf{0} \in \mathbb{N}_0^{K \times T}
        \]

        where $K$ is the number of treatments and $T$ is the number of observations.
        For each observation $t$ in \( t = 1, \ldots, T \):
        \[
        L_{k,t} \sim \text{Binomial}\left(n=N_{k} - \sum_{t'=0}^{t-1} L_{k,t'},~p=1 - \frac{S_{k,t}}{S_{k,t-1}}\right)
        \]
        Where:
        - \( L_{k,t} \) is the number of organisms dying between $t$ and $t-1$ in each treatment $k$.
        - \( N_{k} \) is the initial number of alive organism in each treatment \( k \) at the start of the observations $t=0$.
        - \( S_{k,t} \) is survival function (the probability of being alive at a given time) computed for each observations $t$ in $T$.

        Finally:
        \[
        S_{k,t}^{obs} = N_{k} - \sum_{t'=0}^{t} L_{k,t'}
        \]


        Parameters
        ----------
        n: int or Array of int
            The initial number of positive observations (e.g. initial number of organisms
            in a survival trial).
        p: Array
            Survival function of the repeated observation.
            p must be monotonically decreasing.
        size: tuple, optional
            Shape of the random variates to generate. The last axis is the
            time axis; the leading axes must broadcast against the batch
            axis of ``p``. Defaults to the shape of the 2D version of ``p``.
        random_state: RandomState or int, optional
            Passed to ``_get_random_state`` to obtain the random number
            generator used for the binomial draws.

        Returns
        -------
        ndarray
            Random samples from the conditional binomial distribution representing the
            number of entities surviving at each time step.


        """
        random_state = self._get_random_state(random_state)

        p = np.array(p, ndmin=2)
        p_init = np.broadcast_to(self.p_init, p[:, [0]].shape)
        n_init = np.broadcast_to(n, p.shape)

        p_conditional = conditional_prob_neglogexp(p, p_init=p_init, eps=self.eps)

        if size is None:
            size = p.shape

        # leading axes are batch dimensions
        # the last axis is the time dimension (probability)
        L = np.zeros(shape=size)

        # Iterate over the time axis, which is the LAST axis of L. Using
        # axis 1 here would skip time steps whenever ``size`` has more than
        # two dimensions.
        for i in range(L.shape[-1]):
            # calculate the binomial response of the conditional survival
            # i.e. the probability to die within an interval conditional on
            # having survived until the beginning of that interval
            L[..., i] = random_state.binomial(
                p=1 - p_conditional[:, i],
                # organisms still alive: initial count minus all deaths so far
                n=n_init[:, i] - L.sum(axis=-1).astype(int),
                size=size[:-1]
            )

        return n_init - L.cumsum(axis=-1)
|
|
181
|
+
|
|
182
|
+
# Module-level instance; calling it with ``n`` and ``p`` returns a frozen
# distribution (see ``conditional_binomial.__call__``).
conditional_survival = conditional_binomial()
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class conditional_binomial_frozen(multinomial_frozen):
    """Conditional binomial distribution with fixed parameters ``n`` and ``p``.

    Parameters
    ----------
    n: int or array_like
        Initial number of positive observations.
    p: array_like
        Survival function of the repeated observation.
    seed: optional
        Forwarded to the underlying ``conditional_binomial`` instance.
    """

    def __init__(self, n, p, seed=None):
        # The parent initialiser is not called; the parameters are stored
        # as given and handed to the underlying distribution on every call.
        self._dist = conditional_binomial(seed)
        self.n, self.p = n, p

    def logpmf(self, x):
        """Log-probability of the observations ``x`` under the stored parameters."""
        return self._dist.logpmf(x, n=self.n, p=self.p)

    def pmf(self, x):
        """Probability of the observations ``x`` under the stored parameters."""
        return self._dist.pmf(x, n=self.n, p=self.p)

    def rvs(self, n=None, size=None, random_state=None):
        """Draw random samples.

        Parameters
        ----------
        n: int or array_like, optional
            Initial number of positive observations used for this draw.
            Defaults to the ``n`` the distribution was frozen with.
        size: tuple, optional
            Shape of the random variates to generate.
        random_state: optional
            Forwarded to the underlying distribution.
        """
        if n is None:
            n = self.n
        # pass the resolved ``n`` so that an explicit override takes effect
        return self._dist.rvs(n=n, p=self.p, size=size, random_state=random_state)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
if __name__ == "__main__":
    # Smoke-test script: exercises pmf and rvs with different combinations of
    # parameter and observation shapes.
    survival_curve = [0.8, 0.4, 0.2]
    observed_batch = [[10, 5, 2], [7, 2, 0]]
    observed_single = [10, 5, 2]

    # two treatments sharing one initial count
    S = conditional_survival(n=10, p=[survival_curve, survival_curve])
    prob = S.pmf(observed_batch)
    prob = S.pmf(observed_single)
    sample = S.rvs()
    sample = S.rvs(size=(10, 2, 3))

    # a single treatment given as a 1D survival function
    S = conditional_survival(n=10, p=survival_curve)
    prob = S.pmf(observed_batch)
    prob = S.pmf(observed_single)
    S.rvs(size=(10, 3))
    S.rvs()

    # two treatments with individual initial counts
    S = conditional_survival(
        n=[[10], [2000]],
        p=[survival_curve, survival_curve],
    )
    prob = S.pmf(observed_batch)
    S.rvs()
|
|
232
|
+
|
|
233
|
+
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
import numpy as np
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import numpyro
|
|
6
|
+
from numpyro.infer import Predictive
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def survival_predictions(
    probs, n_trials,
    eps: float = 0.0,
    seed=1,
    mode: Literal["survival", "lethality", "deaths"] = "survival"
):
    """Generate predictions for survival based on a multinomial survival distribution

    Parameters
    ----------
    probs : ArrayLike
        2D Array of survival probabilities per experiment and time. The
        multinomial probabilities of deaths for each time interval are
        derived from the differences between consecutive time points.
        dims=(experiment, time)
    n_trials : ArrayLike
        1D Array denoting the number of organisms at the beginning in each experiment
        dims = (experiment,)
    eps : float, optional
        Resolution to which the survival probabilities are truncated and
        bound used to clip the death probabilities into ``[eps, 1 - eps]``.
        With the default of 0.0 no truncation is applied.
    seed : int, optional
        Seed for the random number generator, by default 1
    mode : str, optional
        How should the random draws be returned?
        - survival: Decreasing from n_trials to 0
        - lethality: Increasing from 0 to n_trials
        - deaths: Between 0 and n_trials in each interval. Summing to n_trials

    Returns
    -------
    Array
        Random draws with dims=(experiment, time) in the requested mode.

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of 'survival', 'lethality', or 'deaths'.
    """
    if mode not in ("survival", "lethality", "deaths"):
        raise NotImplementedError(
            f"Mode {mode} is not implemented. " +
            "Use one of 'survival', 'lethality', or 'deaths'."
        )

    def survival_to_death_probs(pr_survival):
        # truncate here, because numeric errors below the solver tolerance can
        # lead to negative values in the difference. This needs to be cured here
        if np.all(np.asarray(eps) > 0):
            pr_survival_ = np.trunc(pr_survival / eps) * eps
        else:
            # Truncating to a resolution of zero would divide by zero and
            # turn every probability into NaN. Negative differences are
            # still removed by the clipping below.
            pr_survival_ = np.asarray(pr_survival)
        pr_death = pr_survival_[:-1] - pr_survival_[1:]

        pr_death = np.concatenate([
            # concatenate a zero at the beginning in order to "simulate" no
            # deaths at T = 0
            jnp.zeros((1,)),
            # Delta S
            pr_death,
            # The remaining mortality as T -> infinity
            jnp.ones((1,)) - pr_death.sum()
        ])

        # make sure the vector is not zero or 1 (this is always problematic for
        # probabilities) and make sure the vector sums to 1
        pr_death = np.clip(pr_death, eps, 1 - eps)
        pr_death = pr_death / pr_death.sum()
        return pr_death

    rng = np.random.default_rng(seed)
    # one multinomial draw per experiment, in the order of the experiments
    deaths = jnp.array([
        rng.multinomial(n=n, pvals=survival_to_death_probs(pr_survival))
        for n, pr_survival in zip(n_trials, probs)
    ])

    # remove the last observations to trim off the simulated unobserved mortality
    deaths = deaths[:, :-1]

    if mode == "deaths":
        return deaths
    if mode == "lethality":
        return deaths.cumsum(axis=1)
    # mode == "survival"
    return np.expand_dims(n_trials, axis=1) - deaths.cumsum(axis=1)
|
|
81
|
+
|
|
82
|
+
def posterior_predictions(sim, idata, seed=None):
    """Make posterior predictions for survival data

    The inference model is re-evaluated for the unconstrained posterior
    samples stored in ``idata``. The re-simulated posterior and model fits
    are checked against the stored ones, and afterwards the
    ``posterior_model_fits`` and ``posterior_predictive`` groups of ``idata``
    are replaced.

    Parameters
    ----------
    sim
        Simulation object providing ``config``, ``inferer`` and
        ``coordinates``.
    idata
        Inference data with the groups ``posterior``,
        ``unconstrained_posterior`` and ``posterior_model_fits``.
        It is modified in place.
    seed : int, optional
        Seed for the random draws. Defaults to ``sim.config.simulation.seed``.

    Returns
    -------
    The ``idata`` object that was passed in, with updated groups.

    Raises
    ------
    AssertionError
        If the mean absolute deviation between the re-simulated and the
        stored posterior (or model fits) is not below the tolerance.
    """
    if seed is None:
        seed = sim.config.simulation.seed

    # ``sizes`` maps dimension names to lengths. Item access on ``dims`` is
    # deprecated for datasets (assumes the groups are xarray Datasets, as
    # suggested by the ``to_dict()["data_vars"]`` access below).
    n = idata.posterior.sizes["draw"]
    c = idata.posterior.sizes["chain"]

    key = jax.random.PRNGKey(seed)

    obs, masks = sim.inferer.observation_parser()

    model_kwargs = sim.inferer.preprocessing(obs=obs, masks=masks)
    if sim.config.inference_numpyro.user_defined_error_model is None:
        # without observations the model samples the survival observations
        model_kwargs["obs"]["survival"] = None

    # prepare model
    model = partial(
        sim.inferer.inference_model,
        solver=sim.inferer.evaluator,
        **model_kwargs
    )

    posterior_samples = {
        k: np.array(v["data"]) for k, v
        in idata.unconstrained_posterior.to_dict()["data_vars"].items()
    }

    predictive = Predictive(
        model, posterior_samples=posterior_samples,
        num_samples=n, batch_ndims=2
    )

    samples = predictive(key)

    if sim.config.inference_numpyro.user_defined_error_model is not None:
        n_trials = obs["survival"][:, 0].astype(int)
        # Every draw gets its own reproducible random stream. Re-using one
        # seed for all draws would apply identical random numbers to each
        # posterior sample.
        chain_seeds = np.random.SeedSequence(seed).spawn(c)
        chains = []
        for i in range(c):
            # TODO: Switch to vmap and jit, but it did not work, so if you do it TEST IT!!!!!!!
            survival_draws = samples["survival"][i]
            draw_seeds = chain_seeds[i].spawn(len(survival_draws))
            predictions = [
                survival_predictions(
                    probs,
                    n_trials=n_trials,
                    eps=obs["eps"],
                    seed=draw_seed,
                    mode="survival",
                )
                for probs, draw_seed in zip(survival_draws, draw_seeds)
            ]
            chains.append(predictions)

        posterior_predictive = {"survival_obs": np.array(chains)}
    else:
        posterior_predictive = {"survival_obs": samples.pop("survival_obs")}

    new_idata = sim.inferer.to_arviz_idata(
        posterior=samples,
        posterior_predictive=posterior_predictive,
        n_draws=n,
        n_chains=c
    )

    # update chain names in case they were subselected (clustering)
    new_idata = new_idata.assign_coords({"chain": idata.posterior.chain.values})

    # assert the new posterior matches the old posterior
    tol = sim.config.jaxsolver.atol * 100
    abs_diff_posterior = np.abs(idata.posterior - new_idata.posterior)
    np.testing.assert_array_less(abs_diff_posterior.mean().to_array(), tol)

    # the tolerance for the model fits is scaled with the largest time value
    fit_tol = tol * sim.coordinates["time"].max()
    abs_diff_fits = np.abs(new_idata.posterior_model_fits - idata.posterior_model_fits)
    np.testing.assert_array_less(abs_diff_fits.mean().to_array(), fit_tol)

    idata.posterior_model_fits = new_idata.posterior_model_fits
    idata.posterior_predictive = new_idata.posterior_predictive

    return idata
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from . import base
|
|
2
|
+
from . import ecx
|
|
3
|
+
from . import report
|
|
4
|
+
from . import utils
|
|
5
|
+
from . import transformer
|
|
6
|
+
from . import config
|
|
7
|
+
from . import units
|
|
8
|
+
|
|
9
|
+
from .base import (
|
|
10
|
+
GutsBase,
|
|
11
|
+
GutsSimulationConstantExposure,
|
|
12
|
+
GutsSimulationVariableExposure
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from .ecx import ECxEstimator, LPxEstimator
|
|
16
|
+
from .report import GutsReport, ParameterConverter
|
|
17
|
+
|
|
18
|
+
from .mempy import PymobSimulator
|
|
19
|
+
from .utils import (
|
|
20
|
+
GutsBaseError
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
from .constructors import (
|
|
24
|
+
construct_sim_from_config,
|
|
25
|
+
load_idata
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
from .config import GutsBaseConfig
|