bayinx 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +3 -1
- bayinx/constraints/lower.py +3 -4
- bayinx/core/__init__.py +7 -4
- bayinx/core/{constraint.py → _constraint.py} +1 -1
- bayinx/core/{model.py → _model.py} +7 -11
- bayinx/core/{parameter.py → _parameter.py} +2 -2
- bayinx/core/{variational.py → _variational.py} +30 -11
- bayinx/dists/__init__.py +2 -2
- bayinx/dists/censored/__init__.py +2 -2
- bayinx/dists/censored/gamma2/__init__.py +1 -1
- bayinx/dists/censored/gamma2/r.py +10 -10
- bayinx/dists/censored/posnormal/__init__.py +1 -1
- bayinx/dists/censored/posnormal/r.py +19 -9
- bayinx/dists/gamma2.py +12 -2
- bayinx/dists/normal.py +50 -34
- bayinx/dists/posnormal.py +77 -75
- bayinx/mhx/vi/__init__.py +5 -3
- bayinx/mhx/vi/meanfield.py +5 -6
- bayinx/mhx/vi/normalizing_flow.py +4 -4
- bayinx/mhx/vi/standard.py +4 -8
- {bayinx-0.3.5.dist-info → bayinx-0.3.7.dist-info}/METADATA +2 -1
- bayinx-0.3.7.dist-info/RECORD +35 -0
- bayinx-0.3.7.dist-info/licenses/LICENSE +21 -0
- bayinx-0.3.5.dist-info/RECORD +0 -34
- bayinx/core/{flow.py → _flow.py} +0 -0
- {bayinx-0.3.5.dist-info → bayinx-0.3.7.dist-info}/WHEEL +0 -0
bayinx/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1
|
-
from bayinx.core import Model
|
2
|
-
|
3
|
-
|
1
|
+
from bayinx.core import Model, Parameter, constrain
|
2
|
+
|
3
|
+
__all__ = ["Model", "Parameter", "constrain"]
|
bayinx/constraints/__init__.py
CHANGED
bayinx/constraints/lower.py
CHANGED
@@ -5,8 +5,7 @@ import jax.numpy as jnp
|
|
5
5
|
import jax.tree as jt
|
6
6
|
from jaxtyping import PyTree, Scalar, ScalarLike
|
7
7
|
|
8
|
-
from bayinx.core
|
9
|
-
from bayinx.core.parameter import Parameter
|
8
|
+
from bayinx.core import Constraint, Parameter
|
10
9
|
|
11
10
|
|
12
11
|
class Lower(Constraint):
|
@@ -39,8 +38,8 @@ class Lower(Constraint):
|
|
39
38
|
dyn_params, static_params = eqx.partition(x, filter_spec)
|
40
39
|
|
41
40
|
# Compute density adjustment
|
42
|
-
laj: PyTree = jt.map(jnp.sum, dyn_params)
|
43
|
-
laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
|
41
|
+
laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
|
42
|
+
laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
|
44
43
|
|
45
44
|
# Compute transformation
|
46
45
|
dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
|
bayinx/core/__init__.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
-
from
|
2
|
-
from
|
3
|
-
from
|
4
|
-
from
|
1
|
+
from ._constraint import Constraint
|
2
|
+
from ._flow import Flow
|
3
|
+
from ._model import Model, constrain
|
4
|
+
from ._parameter import Parameter
|
5
|
+
from ._variational import Variational
|
6
|
+
|
7
|
+
__all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
|
@@ -7,13 +7,13 @@ import jax.numpy as jnp
|
|
7
7
|
import jax.tree as jt
|
8
8
|
from jaxtyping import Scalar
|
9
9
|
|
10
|
-
from
|
11
|
-
from
|
10
|
+
from ._constraint import Constraint
|
11
|
+
from ._parameter import Parameter
|
12
12
|
|
13
13
|
|
14
14
|
def constrain(constraint: Constraint):
|
15
15
|
"""Defines constraint metadata."""
|
16
|
-
return field(metadata={
|
16
|
+
return field(metadata={"constraint": constraint})
|
17
17
|
|
18
18
|
|
19
19
|
class Model(eqx.Module):
|
@@ -49,12 +49,11 @@ class Model(eqx.Module):
|
|
49
49
|
filter_spec = eqx.tree_at(
|
50
50
|
lambda model: getattr(model, f.name),
|
51
51
|
filter_spec,
|
52
|
-
replace=attr.filter_spec
|
52
|
+
replace=attr.filter_spec,
|
53
53
|
)
|
54
54
|
|
55
55
|
return filter_spec
|
56
56
|
|
57
|
-
|
58
57
|
@eqx.filter_jit
|
59
58
|
def constrain_params(self) -> Tuple[Self, Scalar]:
|
60
59
|
"""
|
@@ -71,18 +70,16 @@ class Model(eqx.Module):
|
|
71
70
|
attr = getattr(self, f.name)
|
72
71
|
|
73
72
|
# Check if constrained parameter
|
74
|
-
if isinstance(attr, Parameter) and
|
73
|
+
if isinstance(attr, Parameter) and "constraint" in f.metadata:
|
75
74
|
param = attr
|
76
|
-
constraint = f.metadata[
|
75
|
+
constraint = f.metadata["constraint"]
|
77
76
|
|
78
77
|
# Apply constraint
|
79
78
|
param, laj = constraint.constrain(param)
|
80
79
|
|
81
80
|
# Update parameters for constrained model
|
82
81
|
constrained = eqx.tree_at(
|
83
|
-
lambda model: getattr(model, f.name),
|
84
|
-
constrained,
|
85
|
-
replace=param
|
82
|
+
lambda model: getattr(model, f.name), constrained, replace=param
|
86
83
|
)
|
87
84
|
|
88
85
|
# Adjust posterior density
|
@@ -90,7 +87,6 @@ class Model(eqx.Module):
|
|
90
87
|
|
91
88
|
return constrained, target
|
92
89
|
|
93
|
-
|
94
90
|
@eqx.filter_jit
|
95
91
|
def transform_params(self) -> Tuple[Self, Scalar]:
|
96
92
|
"""
|
@@ -4,7 +4,7 @@ import equinox as eqx
|
|
4
4
|
import jax.tree as jt
|
5
5
|
from jaxtyping import PyTree
|
6
6
|
|
7
|
-
T = TypeVar(
|
7
|
+
T = TypeVar("T", bound=PyTree)
|
8
8
|
class Parameter(eqx.Module, Generic[T]):
|
9
9
|
"""
|
10
10
|
A container for a parameter of a `Model`.
|
@@ -14,8 +14,8 @@ class Parameter(eqx.Module, Generic[T]):
|
|
14
14
|
# Attributes
|
15
15
|
- `vals`: The parameter's value(s).
|
16
16
|
"""
|
17
|
-
vals: T
|
18
17
|
|
18
|
+
vals: T
|
19
19
|
|
20
20
|
def __init__(self, values: T):
|
21
21
|
# Insert parameter values
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from abc import abstractmethod
|
2
2
|
from functools import partial
|
3
|
-
from typing import Any, Callable, Self, Tuple
|
3
|
+
from typing import Any, Callable, Generic, Self, Tuple, TypeVar
|
4
4
|
|
5
5
|
import equinox as eqx
|
6
6
|
import jax
|
@@ -11,10 +11,10 @@ import optax as opx
|
|
11
11
|
from jaxtyping import Array, Key, PyTree, Scalar
|
12
12
|
from optax import GradientTransformation, OptState, Schedule
|
13
13
|
|
14
|
-
from
|
14
|
+
from ._model import Model
|
15
15
|
|
16
|
-
|
17
|
-
class Variational(eqx.Module):
|
16
|
+
M = TypeVar('M', bound=Model)
|
17
|
+
class Variational(eqx.Module, Generic[M]):
|
18
18
|
"""
|
19
19
|
An abstract base class used to define variational methods.
|
20
20
|
|
@@ -23,8 +23,8 @@ class Variational(eqx.Module):
|
|
23
23
|
- `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
|
24
24
|
"""
|
25
25
|
|
26
|
-
_unflatten: Callable[[Array],
|
27
|
-
_constraints:
|
26
|
+
_unflatten: Callable[[Array], M]
|
27
|
+
_constraints: M
|
28
28
|
|
29
29
|
@abstractmethod
|
30
30
|
def filter_spec(self):
|
@@ -34,7 +34,7 @@ class Variational(eqx.Module):
|
|
34
34
|
pass
|
35
35
|
|
36
36
|
@abstractmethod
|
37
|
-
def sample(self, n: int, key: Key) -> Array:
|
37
|
+
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
38
38
|
"""
|
39
39
|
Sample from the variational distribution.
|
40
40
|
"""
|
@@ -72,10 +72,10 @@ class Variational(eqx.Module):
|
|
72
72
|
- `data`: Data used to evaluate the posterior(if needed).
|
73
73
|
"""
|
74
74
|
# Unflatten variational draw
|
75
|
-
model:
|
75
|
+
model: M = self._unflatten(draws)
|
76
76
|
|
77
77
|
# Combine with constraints
|
78
|
-
model:
|
78
|
+
model: M = eqx.combine(model, self._constraints)
|
79
79
|
|
80
80
|
# Evaluate posterior density
|
81
81
|
return model.eval(data)
|
@@ -106,8 +106,8 @@ class Variational(eqx.Module):
|
|
106
106
|
dyn, static = eqx.partition(self, self.filter_spec)
|
107
107
|
|
108
108
|
# Construct scheduler
|
109
|
-
schedule: Schedule = opx.
|
110
|
-
init_value=learning_rate, decay_steps=max_iters
|
109
|
+
schedule: Schedule = opx.warmup_cosine_decay_schedule(
|
110
|
+
init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
|
111
111
|
)
|
112
112
|
|
113
113
|
# Initialize optimizer
|
@@ -160,3 +160,22 @@ class Variational(eqx.Module):
|
|
160
160
|
|
161
161
|
# Return optimized variational
|
162
162
|
return eqx.combine(dyn, static)
|
163
|
+
|
164
|
+
@eqx.filter_jit
|
165
|
+
def posterior_predictive(
|
166
|
+
self, func: Callable[[M], Array], n: int, key: Key = jr.PRNGKey(0)
|
167
|
+
) -> Array:
|
168
|
+
# Sample draws from the variational approximation
|
169
|
+
draws: Array = self.sample(n, key)
|
170
|
+
|
171
|
+
# Evaluate posterior predictive
|
172
|
+
@jax.jit
|
173
|
+
@jax.vmap
|
174
|
+
def evaluate(draw: Array):
|
175
|
+
# Reconstruct model
|
176
|
+
model: M = self._unflatten(draw)
|
177
|
+
|
178
|
+
# Evaluate
|
179
|
+
return func(model)
|
180
|
+
|
181
|
+
return evaluate(draws)
|
bayinx/dists/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1
|
-
from bayinx.dists import normal, posnormal
|
1
|
+
from bayinx.dists import censored, gamma2, normal, posnormal
|
2
2
|
|
3
|
-
__all__ = ['
|
3
|
+
__all__ = ['censored', "gamma2", "normal", "posnormal"]
|
@@ -1,3 +1,3 @@
|
|
1
|
-
from . import
|
1
|
+
from . import posnormal
|
2
2
|
|
3
|
-
__all__ = [
|
3
|
+
__all__ = ["posnormal"]
|
@@ -10,7 +10,7 @@ def prob(
|
|
10
10
|
x: Float[ArrayLike, "..."],
|
11
11
|
mu: Float[ArrayLike, "..."],
|
12
12
|
nu: Float[ArrayLike, "..."],
|
13
|
-
censor: Float[ArrayLike, "..."]
|
13
|
+
censor: Float[ArrayLike, "..."],
|
14
14
|
) -> Float[Array, "..."]:
|
15
15
|
"""
|
16
16
|
The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
|
@@ -24,15 +24,15 @@ def prob(
|
|
24
24
|
# Returns
|
25
25
|
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
|
26
26
|
"""
|
27
|
-
evals: Array = jnp.zeros_like(x * 1.0)
|
27
|
+
evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
|
28
28
|
|
29
29
|
# Construct boolean masks
|
30
|
-
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor))
|
31
|
-
censored: Array = jnp.array(x == censor)
|
30
|
+
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
|
31
|
+
censored: Array = jnp.array(x == censor) # pyright: ignore
|
32
32
|
|
33
33
|
# Evaluate probability mass/density function
|
34
34
|
evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
|
35
|
-
evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals)
|
35
|
+
evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
|
36
36
|
|
37
37
|
return evals
|
38
38
|
|
@@ -41,7 +41,7 @@ def logprob(
|
|
41
41
|
x: Float[ArrayLike, "..."],
|
42
42
|
mu: Float[ArrayLike, "..."],
|
43
43
|
nu: Float[ArrayLike, "..."],
|
44
|
-
censor: Float[ArrayLike, "..."]
|
44
|
+
censor: Float[ArrayLike, "..."],
|
45
45
|
) -> Float[Array, "..."]:
|
46
46
|
"""
|
47
47
|
The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
|
@@ -55,14 +55,14 @@ def logprob(
|
|
55
55
|
# Returns
|
56
56
|
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
|
57
57
|
"""
|
58
|
-
evals: Array = jnp.full_like(x * 1.0, -jnp.inf)
|
58
|
+
evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
|
59
59
|
|
60
60
|
# Construct boolean masks
|
61
|
-
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor))
|
62
|
-
censored: Array = jnp.array(x == censor)
|
61
|
+
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
|
62
|
+
censored: Array = jnp.array(x == censor) # pyright: ignore
|
63
63
|
|
64
64
|
# Evaluate log probability mass/density function
|
65
65
|
evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
|
66
|
-
evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals)
|
66
|
+
evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
|
67
67
|
|
68
68
|
return evals
|
@@ -8,10 +8,10 @@ def prob(
|
|
8
8
|
x: Float[ArrayLike, "..."],
|
9
9
|
mu: Float[ArrayLike, "..."],
|
10
10
|
sigma: Float[ArrayLike, "..."],
|
11
|
-
censor: Float[ArrayLike, "..."]
|
11
|
+
censor: Float[ArrayLike, "..."],
|
12
12
|
) -> Float[Array, "..."]:
|
13
13
|
"""
|
14
|
-
The mixed probability mass/density function (PMF/PDF) for a censored positive Normal distribution.
|
14
|
+
The mixed probability mass/density function (PMF/PDF) for a right-censored positive Normal distribution.
|
15
15
|
|
16
16
|
# Parameters
|
17
17
|
- `x`: Value(s) at which to evaluate the PMF/PDF.
|
@@ -23,7 +23,12 @@ def prob(
|
|
23
23
|
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
|
24
24
|
"""
|
25
25
|
# Cast to Array
|
26
|
-
x, mu, sigma, censor =
|
26
|
+
x, mu, sigma, censor = (
|
27
|
+
jnp.asarray(x),
|
28
|
+
jnp.asarray(mu),
|
29
|
+
jnp.asarray(sigma),
|
30
|
+
jnp.asarray(censor),
|
31
|
+
)
|
27
32
|
|
28
33
|
# Construct boolean masks
|
29
34
|
uncensored: Array = jnp.logical_and(0.0 < x, x < censor)
|
@@ -31,7 +36,7 @@ def prob(
|
|
31
36
|
|
32
37
|
# Evaluate probability mass/density function
|
33
38
|
evals = jnp.where(uncensored, posnormal.prob(x, mu, sigma), 0.0)
|
34
|
-
evals = jnp.where(censored, posnormal.ccdf(x,mu,sigma), evals)
|
39
|
+
evals = jnp.where(censored, posnormal.ccdf(x, mu, sigma), evals)
|
35
40
|
|
36
41
|
return evals
|
37
42
|
|
@@ -40,10 +45,10 @@ def logprob(
|
|
40
45
|
x: Float[ArrayLike, "..."],
|
41
46
|
mu: Float[ArrayLike, "..."],
|
42
47
|
sigma: Float[ArrayLike, "..."],
|
43
|
-
censor: Float[ArrayLike, "..."]
|
48
|
+
censor: Float[ArrayLike, "..."],
|
44
49
|
) -> Float[Array, "..."]:
|
45
50
|
"""
|
46
|
-
The log-transformed mixed probability mass/density function (log PMF/PDF) for a censored positive Normal distribution.
|
51
|
+
The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored positive Normal distribution.
|
47
52
|
|
48
53
|
# Parameters
|
49
54
|
- `x`: Where to evaluate the log PMF/PDF.
|
@@ -55,10 +60,15 @@ def logprob(
|
|
55
60
|
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
|
56
61
|
"""
|
57
62
|
# Cast to Array
|
58
|
-
x, mu, sigma, censor =
|
63
|
+
x, mu, sigma, censor = (
|
64
|
+
jnp.asarray(x),
|
65
|
+
jnp.asarray(mu),
|
66
|
+
jnp.asarray(sigma),
|
67
|
+
jnp.asarray(censor),
|
68
|
+
)
|
59
69
|
|
60
|
-
# Construct boolean masks
|
61
|
-
uncensored: Array = jnp.logical_and(jnp.
|
70
|
+
# Construct boolean masks for censoring
|
71
|
+
uncensored: Array = jnp.logical_and(jnp.asarray(0.0) < x, x < censor)
|
62
72
|
censored: Array = x == censor
|
63
73
|
|
64
74
|
# Evaluate log probability mass/density function
|
bayinx/dists/gamma2.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import jax.lax as lax
|
2
|
+
import jax.numpy as jnp
|
2
3
|
from jax.scipy.special import gammaln
|
3
4
|
from jaxtyping import Array, ArrayLike, Float
|
4
5
|
|
@@ -17,6 +18,8 @@ def prob(
|
|
17
18
|
# Returns
|
18
19
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
|
19
20
|
"""
|
21
|
+
# Cast to Array
|
22
|
+
x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
|
20
23
|
|
21
24
|
return lax.exp(logprob(x, mu, nu))
|
22
25
|
|
@@ -35,5 +38,12 @@ def logprob(
|
|
35
38
|
# Returns
|
36
39
|
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
|
37
40
|
"""
|
38
|
-
|
39
|
-
|
41
|
+
# Cast to Array
|
42
|
+
x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
|
43
|
+
|
44
|
+
return (
|
45
|
+
-gammaln(nu)
|
46
|
+
+ nu * (lax.log(nu) - lax.log(mu))
|
47
|
+
+ (nu - 1.0) * lax.log(x)
|
48
|
+
- (x * nu / mu)
|
49
|
+
) # pyright: ignore
|
bayinx/dists/normal.py
CHANGED
@@ -7,116 +7,132 @@ __PI = 3.141592653589793
|
|
7
7
|
|
8
8
|
|
9
9
|
def prob(
|
10
|
-
x: Float[ArrayLike, "..."],
|
10
|
+
x: Float[ArrayLike, "..."],
|
11
|
+
mu: Float[ArrayLike, "..."],
|
12
|
+
sigma: Float[ArrayLike, "..."],
|
11
13
|
) -> Float[Array, "..."]:
|
12
14
|
"""
|
13
15
|
The probability density function (PDF) for a Normal distribution.
|
14
16
|
|
15
17
|
# Parameters
|
16
|
-
- `x`:
|
17
|
-
- `mu`:
|
18
|
-
- `sigma`:
|
18
|
+
- `x`: Where to evaluate the PDF.
|
19
|
+
- `mu`: The mean.
|
20
|
+
- `sigma`: The standard deviation.
|
19
21
|
|
20
22
|
# Returns
|
21
23
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
22
24
|
"""
|
23
25
|
# Cast to Array
|
24
|
-
x, mu, sigma = jnp.
|
26
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
25
27
|
|
26
|
-
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (
|
27
|
-
sigma * lax.sqrt(2.0 * __PI)
|
28
|
-
)
|
28
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (sigma * lax.sqrt(2.0 * __PI))
|
29
29
|
|
30
30
|
|
31
31
|
def logprob(
|
32
|
-
x: Float[ArrayLike, "..."],
|
32
|
+
x: Float[ArrayLike, "..."],
|
33
|
+
mu: Float[ArrayLike, "..."],
|
34
|
+
sigma: Float[ArrayLike, "..."],
|
33
35
|
) -> Float[Array, "..."]:
|
34
36
|
"""
|
35
37
|
The log of the probability density function (log PDF) for a Normal distribution.
|
36
38
|
|
37
39
|
# Parameters
|
38
|
-
- `x`:
|
39
|
-
- `mu`:
|
40
|
-
- `sigma`:
|
40
|
+
- `x`: Where to evaluate the log PDF.
|
41
|
+
- `mu`: The mean.
|
42
|
+
- `sigma`: The standard deviation.
|
41
43
|
|
42
44
|
# Returns
|
43
45
|
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
44
46
|
"""
|
45
47
|
# Cast to Array
|
46
|
-
x, mu, sigma = jnp.
|
48
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
47
49
|
|
48
|
-
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
|
49
|
-
(x - mu) / sigma
|
50
|
-
)
|
50
|
+
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square((x - mu) / sigma)
|
51
51
|
|
52
52
|
|
53
53
|
def uprob(
|
54
|
-
x: Float[ArrayLike, "..."],
|
54
|
+
x: Float[ArrayLike, "..."],
|
55
|
+
mu: Float[ArrayLike, "..."],
|
56
|
+
sigma: Float[ArrayLike, "..."],
|
55
57
|
) -> Float[Array, "..."]:
|
56
58
|
"""
|
57
59
|
The unnormalized probability density function (uPDF) for a Normal distribution.
|
58
60
|
|
59
61
|
# Parameters
|
60
|
-
- `x`:
|
61
|
-
- `mu`:
|
62
|
-
- `sigma`:
|
62
|
+
- `x`: Where to evaluate the PDF.
|
63
|
+
- `mu`: The mean.
|
64
|
+
- `sigma`: The standard deviation.
|
63
65
|
|
64
66
|
# Returns
|
65
67
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
66
68
|
"""
|
67
69
|
# Cast to Array
|
68
|
-
x, mu, sigma = jnp.
|
70
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
69
71
|
|
70
72
|
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
|
71
73
|
|
72
74
|
|
73
75
|
def ulogprob(
|
74
|
-
x: Float[ArrayLike, "..."],
|
76
|
+
x: Float[ArrayLike, "..."],
|
77
|
+
mu: Float[ArrayLike, "..."],
|
78
|
+
sigma: Float[ArrayLike, "..."],
|
75
79
|
) -> Float[Array, "..."]:
|
76
80
|
"""
|
77
81
|
The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
|
78
82
|
|
79
83
|
# Parameters
|
80
|
-
- `x`:
|
81
|
-
- `mu`:
|
82
|
-
- `sigma`:
|
84
|
+
- `x`: Where to evaluate the PDF.
|
85
|
+
- `mu`: The mean.
|
86
|
+
- `sigma`: The standard deviation.
|
83
87
|
|
84
88
|
# Returns
|
85
89
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
86
90
|
"""
|
87
91
|
# Cast to Array
|
88
|
-
x, mu, sigma = jnp.
|
92
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
89
93
|
|
90
94
|
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
|
91
95
|
|
96
|
+
|
92
97
|
def cdf(
|
93
|
-
x: Float[ArrayLike, "..."],
|
98
|
+
x: Float[ArrayLike, "..."],
|
99
|
+
mu: Float[ArrayLike, "..."],
|
100
|
+
sigma: Float[ArrayLike, "..."],
|
94
101
|
) -> Float[Array, "..."]:
|
95
102
|
# Cast to Array
|
96
|
-
x, mu, sigma = jnp.
|
103
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
97
104
|
|
98
105
|
return jss.ndtr((x - mu) / sigma)
|
99
106
|
|
107
|
+
|
100
108
|
def logcdf(
|
101
|
-
x: Float[ArrayLike, "..."],
|
109
|
+
x: Float[ArrayLike, "..."],
|
110
|
+
mu: Float[ArrayLike, "..."],
|
111
|
+
sigma: Float[ArrayLike, "..."],
|
102
112
|
) -> Float[Array, "..."]:
|
103
113
|
# Cast to Array
|
104
|
-
x, mu, sigma = jnp.
|
114
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
105
115
|
|
106
116
|
return jss.log_ndtr((x - mu) / sigma)
|
107
117
|
|
118
|
+
|
108
119
|
def ccdf(
|
109
|
-
x: Float[ArrayLike, "..."],
|
120
|
+
x: Float[ArrayLike, "..."],
|
121
|
+
mu: Float[ArrayLike, "..."],
|
122
|
+
sigma: Float[ArrayLike, "..."],
|
110
123
|
) -> Float[Array, "..."]:
|
111
124
|
# Cast to Array
|
112
|
-
x, mu, sigma = jnp.
|
125
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
113
126
|
|
114
127
|
return jss.ndtr((mu - x) / sigma)
|
115
128
|
|
129
|
+
|
116
130
|
def logccdf(
|
117
|
-
x: Float[ArrayLike, "..."],
|
131
|
+
x: Float[ArrayLike, "..."],
|
132
|
+
mu: Float[ArrayLike, "..."],
|
133
|
+
sigma: Float[ArrayLike, "..."],
|
118
134
|
) -> Float[Array, "..."]:
|
119
135
|
# Cast to Array
|
120
|
-
x, mu, sigma = jnp.
|
136
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
121
137
|
|
122
138
|
return jss.log_ndtr((mu - x) / sigma)
|
bayinx/dists/posnormal.py
CHANGED
@@ -5,13 +5,15 @@ from bayinx.dists import normal
|
|
5
5
|
|
6
6
|
|
7
7
|
def prob(
|
8
|
-
x: Float[ArrayLike, "..."],
|
8
|
+
x: Float[ArrayLike, "..."],
|
9
|
+
mu: Float[ArrayLike, "..."],
|
10
|
+
sigma: Float[ArrayLike, "..."],
|
9
11
|
) -> Float[Array, "..."]:
|
10
12
|
"""
|
11
13
|
The probability density function (PDF) for a positive Normal distribution.
|
12
14
|
|
13
15
|
# Parameters
|
14
|
-
- `x`:
|
16
|
+
- `x`: Where to evaluate the PDF.
|
15
17
|
- `mu`: The mean.
|
16
18
|
- `sigma`: The standard deviation.
|
17
19
|
|
@@ -19,28 +21,31 @@ def prob(
|
|
19
21
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
20
22
|
"""
|
21
23
|
# Cast to Array
|
22
|
-
x, mu, sigma = jnp.
|
24
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
23
25
|
|
24
26
|
# Construct boolean mask for non-negative elements
|
25
|
-
non_negative: Array = jnp.
|
27
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
26
28
|
|
27
29
|
# Evaluate PDF
|
28
30
|
evals = jnp.where(
|
29
31
|
non_negative,
|
30
|
-
normal.prob(x, mu, sigma) / normal.cdf(mu/sigma, 0.0, 1.0),
|
31
|
-
jnp.
|
32
|
+
normal.prob(x, mu, sigma) / normal.cdf(mu / sigma, 0.0, 1.0),
|
33
|
+
jnp.asarray(0.0),
|
34
|
+
)
|
32
35
|
|
33
36
|
return evals
|
34
37
|
|
35
38
|
|
36
39
|
def logprob(
|
37
|
-
x: Float[ArrayLike, "..."],
|
40
|
+
x: Float[ArrayLike, "..."],
|
41
|
+
mu: Float[ArrayLike, "..."],
|
42
|
+
sigma: Float[ArrayLike, "..."],
|
38
43
|
) -> Float[Array, "..."]:
|
39
44
|
"""
|
40
45
|
The log of the probability density function (log PDF) for a positive Normal distribution.
|
41
46
|
|
42
47
|
# Parameters
|
43
|
-
- `x`:
|
48
|
+
- `x`: Where to evaluate the log PDF.
|
44
49
|
- `mu`: The mean.
|
45
50
|
- `sigma`: The standard deviation.
|
46
51
|
|
@@ -48,88 +53,89 @@ def logprob(
|
|
48
53
|
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
49
54
|
"""
|
50
55
|
# Cast to Array
|
51
|
-
x, mu, sigma = jnp.
|
56
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
52
57
|
|
53
58
|
# Construct boolean mask for non-negative elements
|
54
|
-
non_negative: Array = jnp.
|
59
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
55
60
|
|
56
61
|
# Evaluate log PDF
|
57
62
|
evals = jnp.where(
|
58
63
|
non_negative,
|
59
|
-
normal.logprob(x, mu, sigma) - normal.logcdf(mu/sigma, 0.0, 1.0),
|
60
|
-
-jnp.inf
|
64
|
+
normal.logprob(x, mu, sigma) - normal.logcdf(mu / sigma, 0.0, 1.0),
|
65
|
+
-jnp.inf,
|
66
|
+
)
|
61
67
|
|
62
68
|
return evals
|
63
69
|
|
64
70
|
|
65
71
|
def uprob(
|
66
|
-
x: Float[ArrayLike, "..."],
|
72
|
+
x: Float[ArrayLike, "..."],
|
73
|
+
mu: Float[ArrayLike, "..."],
|
74
|
+
sigma: Float[ArrayLike, "..."],
|
67
75
|
) -> Float[Array, "..."]:
|
68
76
|
"""
|
69
77
|
The unnormalized probability density function (uPDF) for a positive Normal distribution.
|
70
78
|
|
71
79
|
# Parameters
|
72
|
-
- `x`:
|
73
|
-
- `mu`: The mean
|
74
|
-
- `sigma`: The
|
80
|
+
- `x`: Where to evaluate the uPDF.
|
81
|
+
- `mu`: The mean.
|
82
|
+
- `sigma`: The standard deviation.
|
75
83
|
|
76
84
|
# Returns
|
77
85
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
78
86
|
"""
|
79
87
|
# Cast to Array
|
80
|
-
x, mu, sigma = jnp.
|
88
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
81
89
|
|
82
90
|
# Construct boolean mask for non-negative elements
|
83
|
-
non_negative: Array = jnp.
|
91
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
84
92
|
|
85
93
|
# Evaluate PDF
|
86
|
-
evals = jnp.where(
|
87
|
-
non_negative,
|
88
|
-
normal.prob(x, mu, sigma),
|
89
|
-
jnp.array(0.0))
|
94
|
+
evals = jnp.where(non_negative, normal.prob(x, mu, sigma), jnp.asarray(0.0))
|
90
95
|
|
91
96
|
return evals
|
92
97
|
|
93
98
|
|
94
99
|
def ulogprob(
|
95
|
-
x: Float[ArrayLike, "..."],
|
100
|
+
x: Float[ArrayLike, "..."],
|
101
|
+
mu: Float[ArrayLike, "..."],
|
102
|
+
sigma: Float[ArrayLike, "..."],
|
96
103
|
) -> Float[Array, "..."]:
|
97
104
|
"""
|
98
105
|
The log of the unnormalized probability density function (log uPDF) for a positive Normal distribution.
|
99
106
|
|
100
107
|
# Parameters
|
101
|
-
- `x`:
|
102
|
-
- `mu`: The mean
|
103
|
-
- `sigma`: The
|
108
|
+
- `x`: Where to evaluate the log uPDF.
|
109
|
+
- `mu`: The mean.
|
110
|
+
- `sigma`: The standard deviation.
|
104
111
|
|
105
112
|
# Returns
|
106
113
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
107
114
|
"""
|
108
115
|
# Cast to Array
|
109
|
-
x, mu, sigma = jnp.
|
116
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
110
117
|
|
111
118
|
# Construct boolean mask for non-negative elements
|
112
|
-
non_negative: Array = jnp.
|
119
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
113
120
|
|
114
121
|
# Evaluate log PDF
|
115
|
-
evals = jnp.where(
|
116
|
-
non_negative,
|
117
|
-
normal.logprob(x, mu, sigma),
|
118
|
-
-jnp.inf)
|
122
|
+
evals = jnp.where(non_negative, normal.logprob(x, mu, sigma), -jnp.inf)
|
119
123
|
|
120
124
|
return evals
|
121
125
|
|
122
126
|
|
123
127
|
def cdf(
|
124
|
-
x: Float[ArrayLike, "..."],
|
128
|
+
x: Float[ArrayLike, "..."],
|
129
|
+
mu: Float[ArrayLike, "..."],
|
130
|
+
sigma: Float[ArrayLike, "..."],
|
125
131
|
) -> Float[Array, "..."]:
|
126
132
|
"""
|
127
133
|
The cumulative density function (CDF) for a positive Normal distribution.
|
128
134
|
|
129
135
|
# Parameters
|
130
|
-
- `x`:
|
131
|
-
- `mu`: The mean
|
132
|
-
- `sigma`: The
|
136
|
+
- `x`: Where to evaluate the CDF.
|
137
|
+
- `mu`: The mean.
|
138
|
+
- `sigma`: The standard deviation.
|
133
139
|
|
134
140
|
# Returns
|
135
141
|
The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
@@ -138,35 +144,35 @@ def cdf(
|
|
138
144
|
Not numerically stable for small `x`.
|
139
145
|
"""
|
140
146
|
# Cast to Array
|
141
|
-
x, mu, sigma = jnp.
|
147
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
142
148
|
|
143
149
|
# Construct boolean mask for non-negative elements
|
144
|
-
non_negative: Array = jnp.
|
150
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
145
151
|
|
146
152
|
# Compute intermediates
|
147
153
|
A: Array = normal.cdf(x, mu, sigma)
|
148
|
-
B: Array = normal.cdf(-
|
154
|
+
B: Array = normal.cdf(-mu / sigma, 0.0, 1.0)
|
149
155
|
C: Array = normal.cdf(mu / sigma, 0.0, 1.0)
|
150
156
|
|
151
157
|
# Evaluate CDF
|
152
|
-
evals = jnp.where(
|
153
|
-
non_negative,
|
154
|
-
(A - B) / C,
|
155
|
-
jnp.array(0.0))
|
158
|
+
evals = jnp.where(non_negative, (A - B) / C, jnp.asarray(0.0))
|
156
159
|
|
157
160
|
return evals
|
158
161
|
|
162
|
+
|
159
163
|
# TODO: make numerically stable
|
160
164
|
def logcdf(
|
161
|
-
x: Float[ArrayLike, "..."],
|
165
|
+
x: Float[ArrayLike, "..."],
|
166
|
+
mu: Float[ArrayLike, "..."],
|
167
|
+
sigma: Float[ArrayLike, "..."],
|
162
168
|
) -> Float[Array, "..."]:
|
163
169
|
"""
|
164
170
|
The log-transformed cumulative density function (log CDF) for a positive Normal distribution.
|
165
171
|
|
166
172
|
# Parameters
|
167
|
-
- `x`:
|
168
|
-
- `mu`: The mean
|
169
|
-
- `sigma`: The
|
173
|
+
- `x`: Where to evaluate the log CDF.
|
174
|
+
- `mu`: The mean.
|
175
|
+
- `sigma`: The standard deviation.
|
170
176
|
|
171
177
|
# Returns
|
172
178
|
The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
@@ -175,84 +181,80 @@ def logcdf(
|
|
175
181
|
Not numerically stable for small `x`.
|
176
182
|
"""
|
177
183
|
# Cast to Array
|
178
|
-
x, mu, sigma = jnp.
|
184
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
179
185
|
|
180
186
|
# Construct boolean mask for non-negative elements
|
181
|
-
non_negative: Array = jnp.
|
187
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
182
188
|
|
183
189
|
A: Array = normal.logcdf(x, mu, sigma)
|
184
|
-
B: Array = normal.logcdf(-
|
185
|
-
C: Array = normal.logcdf(mu/sigma, 0.0, 1.0)
|
190
|
+
B: Array = normal.logcdf(-mu / sigma, 0.0, 1.0)
|
191
|
+
C: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
|
186
192
|
|
187
193
|
# Evaluate log CDF
|
188
|
-
evals = jnp.where(
|
189
|
-
non_negative,
|
190
|
-
A + jnp.log1p(-jnp.exp(B - A)) - C,
|
191
|
-
-jnp.inf)
|
194
|
+
evals = jnp.where(non_negative, A + jnp.log1p(-jnp.exp(B - A)) - C, -jnp.inf)
|
192
195
|
|
193
196
|
return evals
|
194
197
|
|
198
|
+
|
195
199
|
def ccdf(
|
196
|
-
x: Float[ArrayLike, "..."],
|
200
|
+
x: Float[ArrayLike, "..."],
|
201
|
+
mu: Float[ArrayLike, "..."],
|
202
|
+
sigma: Float[ArrayLike, "..."],
|
197
203
|
) -> Float[Array, "..."]:
|
198
204
|
"""
|
199
205
|
The complementary cumulative density function (cCDF) for a positive Normal distribution.
|
200
206
|
|
201
207
|
# Parameters
|
202
|
-
- `x`:
|
203
|
-
- `mu`: The mean
|
204
|
-
- `sigma`: The
|
208
|
+
- `x`: Where to evaluate the cCDF.
|
209
|
+
- `mu`: The mean.
|
210
|
+
- `sigma`: The standard deviation.
|
205
211
|
|
206
212
|
# Returns
|
207
213
|
The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
208
|
-
|
209
|
-
# Notes
|
210
|
-
Not numerically stable for small `x`.
|
211
214
|
"""
|
212
215
|
# Cast to arrays
|
213
|
-
x, mu, sigma = jnp.
|
216
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
214
217
|
|
215
218
|
# Construct boolean mask for non-negative elements
|
216
219
|
non_negative: Array = 0.0 <= x
|
217
220
|
|
218
221
|
# Compute intermediates
|
219
222
|
A: Array = normal.cdf(-x, -mu, sigma)
|
220
|
-
B: Array = normal.cdf(mu/sigma, 0.0, 1.0)
|
223
|
+
B: Array = normal.cdf(mu / sigma, 0.0, 1.0)
|
221
224
|
|
222
225
|
# Evaluate cCDF
|
223
|
-
evals = jnp.where(non_negative, A / B, jnp.
|
226
|
+
evals = jnp.where(non_negative, A / B, jnp.asarray(1.0))
|
224
227
|
|
225
228
|
return evals
|
226
229
|
|
227
230
|
|
228
231
|
def logccdf(
|
229
|
-
x: Float[ArrayLike, "..."],
|
232
|
+
x: Float[ArrayLike, "..."],
|
233
|
+
mu: Float[ArrayLike, "..."],
|
234
|
+
sigma: Float[ArrayLike, "..."],
|
230
235
|
) -> Float[Array, "..."]:
|
231
236
|
"""
|
232
237
|
The log-transformed complementary cumulative density function (log cCDF) for a positive Normal distribution.
|
233
238
|
|
234
239
|
# Parameters
|
235
|
-
- `x`:
|
236
|
-
- `mu`: The mean
|
237
|
-
- `sigma`: The
|
240
|
+
- `x`: Where to evaluate the log cCDF.
|
241
|
+
- `mu`: The mean.
|
242
|
+
- `sigma`: The standard deviation.
|
238
243
|
|
239
244
|
# Returns
|
240
245
|
The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
241
|
-
|
242
|
-
# Notes
|
243
|
-
Not numerically stable for small `x`.
|
244
246
|
"""
|
245
247
|
# Cast to arrays
|
246
|
-
x, mu, sigma = jnp.
|
248
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
247
249
|
|
248
250
|
# Construct boolean mask for non-negative elements
|
249
251
|
non_negative: Array = 0.0 <= x
|
250
252
|
|
251
253
|
# Compute intermediates
|
252
254
|
A: Array = normal.logcdf(-x, -mu, sigma)
|
253
|
-
B: Array = normal.logcdf(mu/sigma, 0.0, 1.0)
|
255
|
+
B: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
|
254
256
|
|
255
257
|
# Evaluate log cCDF
|
256
|
-
evals = jnp.where(non_negative, A - B, jnp.
|
258
|
+
evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
|
257
259
|
|
258
260
|
return evals
|
bayinx/mhx/vi/__init__.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
-
from bayinx.mhx.vi.meanfield import MeanField
|
2
|
-
from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
|
3
|
-
from bayinx.mhx.vi.standard import Standard
|
1
|
+
from bayinx.mhx.vi.meanfield import MeanField
|
2
|
+
from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
|
3
|
+
from bayinx.mhx.vi.standard import Standard
|
4
|
+
|
5
|
+
__all__ = ['MeanField', 'NormalizingFlow', 'Standard']
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Dict, Self
|
1
|
+
from typing import Any, Dict, Generic, Self, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.numpy as jnp
|
@@ -10,8 +10,8 @@ from jaxtyping import Array, Float, Key, Scalar
|
|
10
10
|
from bayinx.core import Model, Variational
|
11
11
|
from bayinx.dists import normal
|
12
12
|
|
13
|
-
|
14
|
-
class MeanField(Variational):
|
13
|
+
M = TypeVar('M', bound=Model)
|
14
|
+
class MeanField(Variational, Generic[M]):
|
15
15
|
"""
|
16
16
|
A fully factorized Gaussian approximation to a posterior distribution.
|
17
17
|
|
@@ -19,9 +19,9 @@ class MeanField(Variational):
|
|
19
19
|
- `var_params`: The variational parameters for the approximation.
|
20
20
|
"""
|
21
21
|
|
22
|
-
var_params: Dict[str, Float[Array, "..."]]
|
22
|
+
var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
|
23
23
|
|
24
|
-
def __init__(self, model:
|
24
|
+
def __init__(self, model: M):
|
25
25
|
"""
|
26
26
|
Constructs an unoptimized meanfield posterior approximation.
|
27
27
|
|
@@ -55,7 +55,6 @@ class MeanField(Variational):
|
|
55
55
|
|
56
56
|
return filter_spec
|
57
57
|
|
58
|
-
|
59
58
|
@eqx.filter_jit
|
60
59
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
61
60
|
# Sample variational draws
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Self, Tuple
|
1
|
+
from typing import Any, Generic, Self, Tuple, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.flatten_util as jfu
|
@@ -9,8 +9,8 @@ from jaxtyping import Array, Key, Scalar
|
|
9
9
|
|
10
10
|
from bayinx.core import Flow, Model, Variational
|
11
11
|
|
12
|
-
|
13
|
-
class NormalizingFlow(Variational):
|
12
|
+
M = TypeVar('M', bound=Model)
|
13
|
+
class NormalizingFlow(Variational, Generic[M]):
|
14
14
|
"""
|
15
15
|
An ordered collection of diffeomorphisms that map a base distribution to a
|
16
16
|
normalized approximation of a posterior distribution.
|
@@ -23,7 +23,7 @@ class NormalizingFlow(Variational):
|
|
23
23
|
flows: list[Flow]
|
24
24
|
base: Variational
|
25
25
|
|
26
|
-
def __init__(self, base: Variational, flows: list[Flow], model:
|
26
|
+
def __init__(self, base: Variational, flows: list[Flow], model: M):
|
27
27
|
"""
|
28
28
|
Constructs an unoptimized normalizing flow posterior approximation.
|
29
29
|
|
bayinx/mhx/vi/standard.py
CHANGED
@@ -1,29 +1,25 @@
|
|
1
|
-
from typing import Callable
|
2
1
|
|
3
2
|
import equinox as eqx
|
4
3
|
import jax.numpy as jnp
|
5
4
|
import jax.random as jr
|
6
5
|
import jax.tree_util as jtu
|
7
6
|
from jax.flatten_util import ravel_pytree
|
8
|
-
from jaxtyping import Array,
|
7
|
+
from jaxtyping import Array, Key
|
9
8
|
|
10
|
-
from bayinx.core import
|
9
|
+
from bayinx.core._variational import M, Variational
|
11
10
|
from bayinx.dists import normal
|
12
11
|
|
13
12
|
|
14
|
-
class Standard(Variational):
|
13
|
+
class Standard(Variational[M]):
|
15
14
|
"""
|
16
15
|
A standard normal approximation to a posterior distribution.
|
17
16
|
|
18
17
|
# Attributes
|
19
18
|
- `dim`: Dimension of the parameter space.
|
20
19
|
"""
|
21
|
-
|
22
20
|
dim: int
|
23
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
-
_constraints: Model
|
25
21
|
|
26
|
-
def __init__(self, model:
|
22
|
+
def __init__(self, model: M):
|
27
23
|
"""
|
28
24
|
Constructs a standard normal approximation to a posterior distribution.
|
29
25
|
|
@@ -0,0 +1,35 @@
|
|
1
|
+
bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
|
4
|
+
bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
|
5
|
+
bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
|
6
|
+
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
7
|
+
bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
|
+
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
9
|
+
bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
|
10
|
+
bayinx/core/_variational.py,sha256=b7xlUcw8JDDBfXgDLMcjsOMHpFZ2Tg3sEt965eWmctI,5431
|
11
|
+
bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
|
12
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
|
+
bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
|
14
|
+
bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
|
15
|
+
bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
|
16
|
+
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
17
|
+
bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
|
18
|
+
bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
19
|
+
bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
|
20
|
+
bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
21
|
+
bayinx/dists/censored/posnormal/r.py,sha256=hyuNR3HZY-Tgtso-WwjcZT6Ejxfyax_VKwIvVix44Jc,2362
|
22
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
23
|
+
bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
|
24
|
+
bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
|
25
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
|
26
|
+
bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
|
27
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
28
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
29
|
+
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
30
|
+
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
31
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
32
|
+
bayinx-0.3.7.dist-info/METADATA,sha256=bQGouAjty73m1UeFCOWgRMw7Is0ffja7xDXAyS-EzDM,3079
|
33
|
+
bayinx-0.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
34
|
+
bayinx-0.3.7.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
35
|
+
bayinx-0.3.7.dist-info/RECORD,,
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2025 Todd McCready
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
bayinx-0.3.5.dist-info/RECORD
DELETED
@@ -1,34 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=5fb_tGeEVnrNt6IQqu7gZaJskBJHqjcg08JRPrY2ANo,139
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
|
4
|
-
bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
|
5
|
-
bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
|
6
|
-
bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
|
7
|
-
bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
|
-
bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
|
9
|
-
bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
|
10
|
-
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
|
-
bayinx/dists/__init__.py,sha256=qPQrl5vkS9K56GzIaHZXkSUP07YAu4lVB8K2yQ1m3SY,78
|
12
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
|
-
bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
14
|
-
bayinx/dists/normal.py,sha256=BLlp7hGMAxUbroROvzA5ChH5YLXgadeK4VOuBtjjdjs,3978
|
15
|
-
bayinx/dists/posnormal.py,sha256=NNr5OHv1fWCxYvc6hwUMIGXX31UAg0sEnc4tsxHLjUg,7726
|
16
|
-
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
17
|
-
bayinx/dists/censored/__init__.py,sha256=p8T03TenD-_8YNiOgB_RKksq8hFNFejA5bnoK4JJ8Ms,67
|
18
|
-
bayinx/dists/censored/gamma2/__init__.py,sha256=qqm0n2hfid617PvyFRHAOMAp3AvpOlt5v3ns8HgTD7E,33
|
19
|
-
bayinx/dists/censored/gamma2/r.py,sha256=dE0MNTAl0E6npQhFONv341U7XbomBB-fNzQhgRjxYpk,2436
|
20
|
-
bayinx/dists/censored/posnormal/__init__.py,sha256=qqm0n2hfid617PvyFRHAOMAp3AvpOlt5v3ns8HgTD7E,33
|
21
|
-
bayinx/dists/censored/posnormal/r.py,sha256=4MfFkQ2klzOZJNjxS9g4zz1bdoJ6ehBxZQi6QkmPGgE,2232
|
22
|
-
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
23
|
-
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
24
|
-
bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
|
25
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
|
26
|
-
bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
|
27
|
-
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
28
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
29
|
-
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
30
|
-
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
31
|
-
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
32
|
-
bayinx-0.3.5.dist-info/METADATA,sha256=Hj8GWJef3kfJ6umsHGIFWovYXXtPegAlcsopunoHFFs,3057
|
33
|
-
bayinx-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
34
|
-
bayinx-0.3.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|