bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/core/_parameter.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
from typing import Generic, Self, TypeVar
|
|
2
|
-
|
|
3
|
-
import equinox as eqx
|
|
4
|
-
import jax.tree as jt
|
|
5
|
-
from jaxtyping import PyTree
|
|
6
|
-
|
|
7
|
-
T = TypeVar("T", bound=PyTree)
|
|
8
|
-
class Parameter(eqx.Module, Generic[T]):
    """
    Wraps the value(s) of a single `Model` parameter.

    Subclass this container to provide a custom filter specification (`filter_spec`).

    # Attributes
    - `vals`: The wrapped parameter value(s).
    """

    vals: T

    def __init__(self, values: T):
        # Store the supplied value(s) as-is
        self.vals = values

    def __call__(self) -> T:
        # Calling the container yields the stored value(s)
        return self.vals

    @property
    @eqx.filter_jit
    def filter_spec(self) -> Self:
        """
        Default filter specification.

        Every leaf of the container is marked `False`, except the leaves under
        `vals`, which are marked according to `eqx.is_array_like`.
        """
        # Mask over `vals`: True for array-like leaves
        array_mask = jt.map(eqx.is_array_like, self.vals)

        # All-False specification with the same structure as this container
        all_false = jt.map(lambda _: False, self)

        # Splice the `vals` mask into the all-False specification
        return eqx.tree_at(
            lambda params: params.vals,
            all_false,
            replace=array_mask,
        )
|
bayinx/core/_variational.py
DELETED
|
@@ -1,181 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from functools import partial
|
|
3
|
-
from typing import Any, Callable, Generic, Self, Tuple, TypeVar
|
|
4
|
-
|
|
5
|
-
import equinox as eqx
|
|
6
|
-
import jax
|
|
7
|
-
import jax.lax as lax
|
|
8
|
-
import jax.numpy as jnp
|
|
9
|
-
import jax.random as jr
|
|
10
|
-
import optax as opx
|
|
11
|
-
from jaxtyping import Array, Key, PyTree, Scalar
|
|
12
|
-
from optax import GradientTransformation, OptState, Schedule
|
|
13
|
-
|
|
14
|
-
from ._model import Model
|
|
15
|
-
|
|
16
|
-
M = TypeVar('M', bound=Model)
|
|
17
|
-
class Variational(eqx.Module, Generic[M]):
    """
    An abstract base class used to define variational methods.

    Subclasses supply `filter_spec`, `sample`, `eval`, `elbo` and `elbo_grad`;
    this class builds model evaluation (`eval_model`), the optimization loop
    (`fit`) and posterior predictive evaluation (`posterior_predictive`) on
    top of them.

    # Attributes
    - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
    - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
    """

    _unflatten: Callable[[Array], M]
    _constraints: M

    @abstractmethod
    def filter_spec(self):
        """
        Filter specification for dynamic and static components of the `Variational`.

        NOTE(review): `fit` reads this as an attribute (`self.filter_spec`), so
        concrete subclasses presumably expose it as a property — confirm.
        """
        pass

    @abstractmethod
    def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
        """
        Sample from the variational distribution.

        # Parameters
        - `n`: Number of draws.
        - `key`: A PRNG key.
        """
        pass

    @abstractmethod
    def eval(self, draws: Array) -> Array:
        """
        Evaluate the variational distribution at `draws`.
        """
        pass

    @abstractmethod
    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
        """
        Evaluate the ELBO.

        # Parameters
        - `n`: Number of variational draws.
        - `key`: A PRNG key.
        - `data`: Data used to evaluate the posterior(if needed).
        """
        pass

    @abstractmethod
    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
        """
        Evaluate the gradient of the ELBO.

        # Parameters
        - `n`: Number of variational draws.
        - `key`: A PRNG key.
        - `data`: Data used to evaluate the posterior(if needed).
        """
        pass

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0, None))
    def eval_model(self, draws: Array, data: Any = None) -> Array:
        """
        Reconstruct models from variational draws and evaluate their posterior density.

        The function is vectorized over the leading axis of `draws` only;
        `self` and `data` are shared across draws.

        # Parameters
        - `draws`: A set of variational draws.
        - `data`: Data used to evaluate the posterior(if needed).
        """
        # Unflatten variational draw
        model: M = self._unflatten(draws)

        # Combine with constraints
        model: M = eqx.combine(model, self._constraints)

        # Evaluate posterior density
        return model.eval(data)

    @eqx.filter_jit
    def fit(
        self,
        max_iters: int,
        data: Any = None,
        learning_rate: float = 1,
        weight_decay: float = 1e-4,
        tolerance: float = 1e-4,
        var_draws: int = 1,
        key: Key = jr.PRNGKey(0),
    ) -> Self:
        """
        Optimize the variational distribution.

        # Parameters
        - `max_iters`: Maximum number of iterations for the optimization loop.
        - `data`: Data to evaluate the posterior density with(if available).
        - `learning_rate`: Peak learning rate reached at the end of warm-up.
        - `weight_decay`: Weight decay passed to the `nadamw` optimizer.
        - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
          NOTE: currently unused; the loop always runs `max_iters` iterations.
        - `var_draws`: Number of variational draws to draw each iteration.
        - `key`: A PRNG key.

        # Returns
        The `Variational` with its dynamic components optimized.
        """
        # Partition variational
        dyn, static = eqx.partition(self, self.filter_spec)

        # Construct scheduler (warm-up over the first tenth of the iterations)
        # NOTE(review): optax documents `decay_steps` as the total schedule
        # length including warm-up; as written the schedule reaches its end
        # value `warmup_steps` iterations before `max_iters` — confirm intent.
        warmup_steps: int = int(max_iters / 10)
        schedule: Schedule = opx.warmup_cosine_decay_schedule(
            init_value=1e-16,
            peak_value=learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=max_iters - warmup_steps,
        )

        # Initialize optimizer; the -1 scaling turns descent into ascent on the ELBO
        optim: GradientTransformation = opx.chain(
            opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
        )
        opt_state: OptState = optim.init(dyn)

        # Optimization loop helper functions
        @eqx.filter_jit
        def condition(state: Tuple[Self, OptState, Scalar, Key]):
            # Unpack iteration state
            dyn, opt_state, i, key = state

            return i < max_iters

        @eqx.filter_jit
        def body(state: Tuple[Self, OptState, Scalar, Key]):
            # Unpack iteration state
            dyn, opt_state, i, key = state

            # Update iteration
            i = i + 1

            # Update PRNG key
            key, _ = jr.split(key)

            # Reconstruct variational
            vari = eqx.combine(dyn, static)

            # Compute gradient of the ELBO
            updates: PyTree = vari.elbo_grad(var_draws, key, data)

            # Compute updates
            updates, opt_state = optim.update(
                updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
            )

            # Update variational distribution
            dyn = eqx.apply_updates(dyn, updates)

            return dyn, opt_state, i, key

        # Run optimization loop
        dyn = lax.while_loop(
            cond_fun=condition,
            body_fun=body,
            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
        )[0]

        # Return optimized variational
        return eqx.combine(dyn, static)

    @eqx.filter_jit
    def posterior_predictive(
        self, func: Callable[[M, Any], Array], n: int, data: Any = None, key: Key = jr.PRNGKey(0)
    ) -> Array:
        """
        Evaluate `func` on models reconstructed from variational draws.

        # Parameters
        - `func`: A function of a reconstructed `Model` and `data`.
        - `n`: Number of variational draws.
        - `data`: Data passed unchanged to every call of `func`.
        - `key`: A PRNG key.

        # Returns
        The outputs of `func` stacked along a leading axis of length `n`.
        """
        # Sample draws from the variational approximation
        draws: Array = self.sample(n, key)

        def evaluate(draw: Array, data: Any = None):
            # Reconstruct model
            model: M = self._unflatten(draw)

            # Combine with constraints, mirroring `eval_model`
            model = eqx.combine(model, self._constraints)

            # Evaluate
            return func(model, data)

        # Vectorize over draws only; `data` is shared across draws rather than
        # being split along its own leading axis
        return jax.vmap(evaluate, in_axes=(0, None))(draws, data)
|
|
@@ -1,68 +0,0 @@
|
|
|
1
|
-
import jax.lax as lax
|
|
2
|
-
import jax.numpy as jnp
|
|
3
|
-
from jax.scipy.special import gammaincc
|
|
4
|
-
from jaxtyping import Array, ArrayLike, Float
|
|
5
|
-
|
|
6
|
-
from bayinx.dists import gamma2
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    nu: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Mixed PMF/PDF of a censored, mean-inverse dispersion parameterized Gamma distribution.

    Points strictly between 0 and `censor` receive the Gamma density, points
    equal to `censor` receive `gammaincc(nu, x * nu / mu)`, and every other
    point receives zero.

    # Parameters
    - `x`: Value(s) at which to evaluate the PMF/PDF.
    - `mu`: The positive mean.
    - `nu`: The positive inverse dispersion.
    - `censor`: The positive censor value.

    # Returns
    The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
    """
    # Classify each evaluation point
    below_censor: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor))  # pyright: ignore
    at_censor: Array = jnp.array(x == censor)  # pyright: ignore

    # Fallback value for points outside (0, censor]; `x * 1.0` forces a float dtype
    outside: Array = jnp.zeros_like(x * 1.0)

    # Density below the censor, mass at the censor
    density: Array = jnp.where(below_censor, gamma2.prob(x, mu, nu), outside)
    return jnp.where(at_censor, gammaincc(nu, x * nu / mu), density)  # pyright: ignore
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    nu: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Log of the mixed PMF/PDF of a censored, mean-inverse dispersion parameterized Gamma distribution.

    Points strictly between 0 and `censor` receive the Gamma log density,
    points equal to `censor` receive `log(gammaincc(nu, x * nu / mu))`, and
    every other point receives `-inf`.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PMF/PDF.
    - `mu`: The positive mean/location.
    - `nu`: The positive inverse dispersion.
    - `censor`: The positive censor value.

    # Returns
    The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
    """
    # Classify each evaluation point
    below_censor: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor))  # pyright: ignore
    at_censor: Array = jnp.array(x == censor)  # pyright: ignore

    # Fallback value for points outside (0, censor]; `x * 1.0` forces a float dtype
    outside: Array = jnp.full_like(x * 1.0, -jnp.inf)

    # Log density below the censor, log mass at the censor
    log_density: Array = jnp.where(below_censor, gamma2.logprob(x, mu, nu), outside)
    log_mass = lax.log(gammaincc(nu, x * nu / mu))  # pyright: ignore
    return jnp.where(at_censor, log_mass, log_density)
|
|
@@ -1,116 +0,0 @@
|
|
|
1
|
-
import jax.numpy as jnp
|
|
2
|
-
import jax.random as jr
|
|
3
|
-
from jaxtyping import Array, ArrayLike, Float, Key
|
|
4
|
-
|
|
5
|
-
from bayinx.dists import posnormal
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Mixed PMF/PDF of a right-censored positive Normal distribution.

    Points strictly between 0 and `censor` receive `posnormal.prob`, points
    equal to `censor` receive `posnormal.ccdf`, and every other point
    receives zero.

    # Parameters
    - `x`: Value(s) at which to evaluate the PMF/PDF.
    - `mu`: The mean.
    - `sigma`: The positive standard deviation.
    - `censor`: The positive censor value.

    # Returns
    The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
    """
    # Work with arrays throughout
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)
    censor = jnp.asarray(censor)

    # Classify each evaluation point
    at_censor: Array = x == censor
    below_censor: Array = jnp.logical_and(0.0 < x, x < censor)

    # Density below the censor, zero elsewhere
    density = jnp.where(below_censor, posnormal.prob(x, mu, sigma), 0.0)

    # Survival mass concentrated at the censor
    return jnp.where(at_censor, posnormal.ccdf(x, mu, sigma), density)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Log of the mixed PMF/PDF of a right-censored positive Normal distribution.

    Points strictly between 0 and `censor` receive `posnormal.logprob`,
    points equal to `censor` receive `posnormal.logccdf`, and every other
    point receives `-inf`.

    # Parameters
    - `x`: Where to evaluate the log PMF/PDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.
    - `censor`: The censor.

    # Returns
    The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
    """
    # Work with arrays throughout
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)
    censor = jnp.asarray(censor)

    # Classify each evaluation point
    at_censor: Array = x == censor
    below_censor: Array = jnp.logical_and(jnp.asarray(0.0) < x, x < censor)

    # Log density below the censor, -inf elsewhere
    log_density = jnp.where(below_censor, posnormal.logprob(x, mu, sigma), -jnp.inf)

    # Log survival mass concentrated at the censor
    return jnp.where(at_censor, posnormal.logccdf(x, mu, sigma), log_density)
|
|
80
|
-
|
|
81
|
-
def sample(
    n: int,
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."] = jnp.inf,
    key: Key = jr.PRNGKey(0)
) -> Float[Array, "..."]:
    """
    Sample from a right-censored positive Normal distribution.

    # Parameters
    - `n`: Number of draws to sample per-parameter.
    - `mu`: The mean.
    - `sigma`: The standard deviation.
    - `censor`: The censor.
    - `key`: A PRNG key.

    # Returns
    Draws from a right-censored positive Normal distribution. The output will have the shape of (n,) + the broadcasted shapes of `mu`, `sigma`, and `censor`.
    """
    # Cast to Array
    mu, sigma, censor = (
        jnp.asarray(mu),
        jnp.asarray(sigma),
        jnp.asarray(censor),
    )

    # Derive shape
    shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape, censor.shape)

    # Truncating X = mu + sigma * Z at X > 0 means truncating the standard
    # normal Z at Z > -mu / sigma. Using 0 as the standardized bound would
    # truncate at X > mu, giving draws that never fall below `mu` and that
    # can be negative when `mu` is negative.
    lower = -mu / sigma

    # Draw from positive normal
    draws = jr.truncated_normal(key, lower, jnp.inf, shape) * sigma + mu

    # Censor values
    draws = jnp.where(censor <= draws, censor, draws)

    return draws
|
bayinx/dists/gamma2.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
import jax.lax as lax
|
|
2
|
-
import jax.numpy as jnp
|
|
3
|
-
from jax.scipy.special import gammaln
|
|
4
|
-
from jaxtyping import Array, ArrayLike, Float
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def prob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The probability density function (PDF) for a (mean-precision parameterized) Gamma distribution.

    Computed as the exponential of `logprob`.

    # Parameters
    - `x`: Value(s) at which to evaluate the PDF.
    - `mu`: The positive mean.
    - `nu`: The positive inverse dispersion.

    # Returns
    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
    """
    # Cast each argument to Array
    x_arr = jnp.asarray(x)
    mean = jnp.asarray(mu)
    inv_dispersion = jnp.asarray(nu)

    # Exponentiate the log density
    log_density = logprob(x_arr, mean, inv_dispersion)
    return lax.exp(log_density)
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def logprob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed probability density function (log PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.

    Evaluates
    `-gammaln(nu) + nu * (log(nu) - log(mu)) + (nu - 1) * log(x) - x * nu / mu`.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PDF.
    - `mu`: The positive mean/location.
    - `nu`: The positive inverse dispersion.

    # Returns
    The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
    """
    # Cast to Array
    x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)

    # `jnp.log` promotes integer inputs to floating point, whereas `lax.log`
    # rejects them, so integer-valued arguments (e.g. `nu=2`) are accepted.
    return (
        -gammaln(nu)
        + nu * (jnp.log(nu) - jnp.log(mu))
        + (nu - 1.0) * jnp.log(x)
        - (x * nu / mu)
    )  # pyright: ignore
|