bayinx 0.1.0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +1 -0
- bayinx/core/__init__.py +3 -0
- bayinx/core/flow.py +68 -0
- bayinx/core/model.py +55 -0
- bayinx/core/utils.py +54 -0
- bayinx/core/variational.py +159 -0
- bayinx/dists/__init__.py +0 -0
- bayinx/dists/bernoulli.py +33 -0
- bayinx/dists/binomial.py +0 -0
- bayinx/dists/gamma.py +0 -0
- bayinx/dists/gamma2.py +0 -0
- bayinx/dists/normal.py +83 -0
- bayinx/machinery/__init__.py +0 -0
- bayinx/machinery/variational/__init__.py +5 -0
- bayinx/machinery/variational/flows/__init__.py +3 -0
- bayinx/machinery/variational/flows/affine.py +68 -0
- bayinx/machinery/variational/flows/planar.py +76 -0
- bayinx/machinery/variational/flows/radial.py +95 -0
- bayinx/machinery/variational/flows/sylvester.py +76 -0
- bayinx/machinery/variational/meanfield.py +124 -0
- bayinx/machinery/variational/normalizing_flow.py +152 -0
- bayinx/machinery/variational/standard.py +67 -0
- bayinx-0.2.3.dist-info/METADATA +40 -0
- bayinx-0.2.3.dist-info/RECORD +26 -0
- bayinx-0.1.0.dist-info/METADATA +0 -8
- bayinx-0.1.0.dist-info/RECORD +0 -5
- {bayinx-0.1.0.dist-info → bayinx-0.2.3.dist-info}/WHEEL +0 -0
bayinx/__init__.py
CHANGED
@@ -0,0 +1 @@
+from bayinx.core.model import Model as Model
bayinx/core/__init__.py
ADDED
bayinx/core/flow.py
ADDED
@@ -0,0 +1,68 @@
+from abc import abstractmethod
+from typing import Callable, Dict, Self, Tuple
+
+import equinox as eqx
+from jaxtyping import Array, Float
+
+from bayinx.core.utils import __MyMeta
+
+
+class Flow(eqx.Module, metaclass=__MyMeta):
+    """
+    A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
+
+    # Attributes
+    - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
+    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
+
+    @abstractmethod
+    def forward(self, draws: Array) -> Array:
+        """
+        Computes the forward transformation of `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
+        """
+        Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
+
+        # Returns
+        A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
+        """
+        pass
+
+    def __init_subclass__(cls):
+        # Add constrain_pars method
+        def constrain_pars(self: Self):
+            """
+            Constrain `params` to the appropriate domain.
+
+            # Returns
+            A dictionary of transformed JAX Arrays representing the constrained parameters.
+            """
+            t_params = self.params
+
+            for par, map in self.constraints.items():
+                t_params[par] = map(t_params[par])
+
+            return t_params
+
+        cls.constrain_pars = eqx.filter_jit(constrain_pars)
+
+        # Add transform_pars method if not present
+        if not callable(getattr(cls, "transform_pars", None)):
+            def transform_pars(self: Self) -> Dict[str, Array]:
+                """
+                Apply a custom transformation to `params` if needed.
+
+                # Returns
+                A dictionary of transformed JAX Arrays representing the transformed parameters.
+                """
+                return self.constrain_pars()
+
+            cls.transform_pars = eqx.filter_jit(transform_pars)
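For orientation, here is a minimal sketch of what a concrete `Flow` looks like under this interface. The `Shift` class is hypothetical (not part of the package) and assumes `Flow` is re-exported from `bayinx.core`, as the flow modules later in this diff do:

```python
from typing import Callable, Dict, Tuple

import jax.numpy as jnp
from jaxtyping import Array

from bayinx.core import Flow


class Shift(Flow):
    """Hypothetical pure-shift flow: f(x) = x + b. Not part of bayinx."""

    params: Dict[str, Array]
    constraints: Dict[str, Callable[[Array], Array]]

    def __init__(self, dim: int):
        self.params = {"shift": jnp.zeros(dim)}
        self.constraints = {}  # the shift needs no constraining transformation

    def forward(self, draws: Array) -> Array:
        # `constrain_pars` is injected by Flow.__init_subclass__
        return draws + self.constrain_pars()["shift"]

    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
        # A pure shift has a unit Jacobian, so the log-abs-det is zero
        return jnp.zeros(draws.shape[0]), self.forward(draws)
```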
bayinx/core/model.py
ADDED
@@ -0,0 +1,55 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict
+
+import equinox as eqx
+from jaxtyping import Array, Scalar
+
+from bayinx.core.utils import __MyMeta
+
+
+class Model(eqx.Module, metaclass=__MyMeta):
+    """
+    A superclass used to define probabilistic models.
+
+    # Attributes
+    - `params`: A dictionary of JAX Arrays representing parameters of the model.
+    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+    """
+
+    params: Dict[str, Array]
+    constraints: Dict[str, Callable[[Array], Array]]
+
+    @abstractmethod
+    def eval(self, data: Any) -> Scalar:
+        pass
+
+    def __init_subclass__(cls):
+        # Add constrain_pars method
+        def constrain_pars(self: Model) -> Dict[str, Array]:
+            """
+            Constrain `params` to the appropriate domain.
+
+            # Returns
+            A dictionary of transformed JAX Arrays representing the constrained parameters.
+            """
+            t_params = self.params
+
+            for par, map in self.constraints.items():
+                t_params[par] = map(t_params[par])
+
+            return t_params
+
+        cls.constrain_pars = eqx.filter_jit(constrain_pars)
+
+        # Add transform_pars method if not present
+        if not callable(getattr(cls, "transform_pars", None)):
+            def transform_pars(self: Model) -> Dict[str, Array]:
+                """
+                Apply a custom transformation to `params` if needed.
+
+                # Returns
+                A dictionary of transformed JAX Arrays representing the transformed parameters.
+                """
+                return self.constrain_pars()
+
+            cls.transform_pars = eqx.filter_jit(transform_pars)
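As a usage illustration, a minimal hypothetical `Model` subclass (an unknown normal mean with a standard normal prior; `Model` is re-exported from `bayinx` per `bayinx/__init__.py` above, and `normal.ulogprob` is defined in `bayinx/dists/normal.py` further down):

```python
import jax.numpy as jnp
from jaxtyping import Scalar

from bayinx import Model
from bayinx.dists import normal


class NormalMean(Model):
    """Hypothetical model: unknown mean `mu` with a standard normal prior."""

    def __init__(self):
        self.params = {"mu": jnp.array(0.0)}
        self.constraints = {}

    def eval(self, data) -> Scalar:
        mu = self.transform_pars()["mu"]

        # Log prior plus log likelihood, up to an additive constant
        target = normal.ulogprob(mu, jnp.array(0.0), jnp.array(1.0))
        target += normal.ulogprob(data, mu, jnp.array(1.0)).sum()
        return target
```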
bayinx/core/utils.py
ADDED
@@ -0,0 +1,54 @@
+from typing import Callable, Dict
+
+import equinox as eqx
+from jaxtyping import Array
+
+
+class __MyMeta(type(eqx.Module)):
+    """
+    Metaclass to ensure attribute types are respected.
+    """
+
+    def __call__(cls, *args, **kwargs):
+        obj = super().__call__(*args, **kwargs)
+
+        # Check parameters are a Dict of JAX Arrays
+        if not isinstance(obj.params, Dict):
+            raise ValueError(
+                f"Model {cls.__name__} must initialize 'params' as a dictionary."
+            )
+
+        for key, value in obj.params.items():
+            if not isinstance(value, Array):
+                raise TypeError(f"Parameter '{key}' must be a JAX Array.")
+
+        # Check constraints are a Dict of functions
+        if not isinstance(obj.constraints, Dict):
+            raise ValueError(
+                f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
+            )
+
+        for key, value in obj.constraints.items():
+            if not isinstance(value, Callable):
+                raise TypeError(f"Constraint for parameter '{key}' must be a function.")
+
+        # Check that the constrain method returns a dict equivalent to `params`
+        t_params: Dict[str, Array] = obj.constrain_pars()
+
+        if not isinstance(t_params, Dict):
+            raise ValueError(
+                f"The 'constrain' method of {cls.__name__} must return a Dict of JAX Arrays."
+            )
+
+        for key, value in t_params.items():
+            if not isinstance(value, Array):
+                raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
+
+            if not value.shape == obj.params[key].shape:
+                raise ValueError(
+                    f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
+                )
+
+        # Check transform_pars
+
+        return obj
bayinx/core/variational.py
ADDED
@@ -0,0 +1,159 @@
+from abc import abstractmethod
+from typing import Any, Callable, Self, Tuple
+
+import equinox as eqx
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.random as jr
+import optax as opx
+from jaxtyping import Array, Float, Key, PyTree, Scalar
+from optax import GradientTransformation, OptState, Schedule
+
+from bayinx.core import Model
+
+
+class Variational(eqx.Module):
+    """
+    A superclass used to define variational methods.
+
+    # Attributes
+    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
+    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
+    """
+
+    _unflatten: Callable[[Float[Array, "..."]], Model]
+    _constraints: Model
+
+    @abstractmethod
+    def sample(self, n: int, key: Key) -> Array:
+        """
+        Sample from the variational distribution.
+        """
+        pass
+
+    @abstractmethod
+    def eval(self, draws: Array) -> Array:
+        """
+        Evaluate the variational distribution at `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
+        """
+        Evaluate the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
+        """
+        Evaluate the gradient of the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def filter_spec(self):
+        """
+        Filter specification for dynamic and static components of the `Variational`.
+        """
+        pass
+
+    def __init_subclass__(cls):
+        """
+        Construct methods that are shared across all VI methods.
+        """
+
+        def eval_model(self, draws: Array, data: Any = None) -> Array:
+            """
+            Reconstruct models from variational draws and evaluate their posterior density.
+
+            # Parameters
+            - `draws`: A set of variational draws.
+            - `data`: Data used to evaluate the posterior (if needed).
+            """
+            # Unflatten variational draw
+            model: Model = self._unflatten(draws)
+
+            # Combine with constraints
+            model: Model = eqx.combine(model, self._constraints)
+
+            # Evaluate posterior density
+            return model.eval(data)
+
+        cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))
+
+        def fit(
+            self,
+            max_iters: int,
+            data: Any = None,
+            learning_rate: float = 1,
+            tolerance: float = 1e-4,
+            var_draws: int = 1,
+            key: Key = jr.PRNGKey(0),
+        ) -> Self:
+            """
+            Optimize the variational distribution.
+
+            # Parameters
+            - `max_iters`: Maximum number of iterations for the optimization loop.
+            - `data`: Data to evaluate the posterior density with (if available).
+            - `learning_rate`: Initial learning rate for optimization.
+            - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
+            - `var_draws`: Number of variational draws to draw each iteration.
+            - `key`: A PRNG key.
+            """
+            # Construct scheduler
+            schedule: Schedule = opx.cosine_decay_schedule(
+                init_value=learning_rate, decay_steps=max_iters
+            )
+
+            # Initialize optimizer
+            optim: GradientTransformation = opx.chain(
+                opx.scale(-1.0), opx.nadam(schedule)
+            )
+            opt_state: OptState = optim.init(eqx.filter(self, self.filter_spec()))
+
+            # Optimization loop helper functions
+            @eqx.filter_jit
+            def condition(state: Tuple[Self, OptState, Scalar, Key]):
+                # Unpack iteration state
+                self, opt_state, i, key = state
+
+                return i < max_iters
+
+            @eqx.filter_jit
+            def body(state: Tuple[Self, OptState, Scalar, Key]):
+                # Unpack iteration state
+                self, opt_state, i, key = state
+
+                # Update iteration
+                i = i + 1
+
+                # Update PRNG key
+                key, _ = jr.split(key)
+
+                # Compute gradient of the ELBO
+                updates: PyTree = self.elbo_grad(var_draws, key, data)
+
+                # Compute updates
+                updates, opt_state = optim.update(
+                    updates, opt_state, eqx.filter(self, self.filter_spec())
+                )
+
+                # Update variational distribution
+                self: Self = eqx.apply_updates(self, updates)
+
+                return self, opt_state, i, key
+
+            # Run optimization loop
+            self = lax.while_loop(
+                cond_fun=condition,
+                body_fun=body,
+                init_val=(self, opt_state, jnp.array(0, jnp.uint32), key),
+            )[0]
+
+            return self
+
+        cls.fit = eqx.filter_jit(fit)
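The `_unflatten`/`_constraints` pair used by `eval_model` follows the usual equinox partition-and-ravel pattern; a sketch of the round trip, reusing the hypothetical `NormalMean` model from the earlier sketch:

```python
import equinox as eqx
import jax.flatten_util as jfu

model = NormalMean()  # hypothetical model from the sketch above

# Split arrays (parameters) from everything else (constraint functions, etc.)
arrays, static = eqx.partition(model, eqx.is_array)

# Flatten the array component into one vector; `unflatten` inverts this
flat, unflatten = jfu.ravel_pytree(arrays)

# A flat variational draw of the same length maps back to a full Model
rebuilt = eqx.combine(unflatten(flat), static)
```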
bayinx/dists/__init__.py
ADDED
File without changes
bayinx/dists/bernoulli.py
ADDED
@@ -0,0 +1,33 @@
+import jax.lax as lax
+from jaxtyping import Array, ArrayLike, Real, UInt
+
+
+# MARK: Functions ----
+def prob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+    """
+    The probability mass function (PMF) for a Bernoulli distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the PMF.
+    - `p`: The probability parameter(s).
+
+    # Returns
+    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+    """
+
+    return lax.pow(p, x) * lax.pow(1 - p, 1 - x)
+
+
+def logprob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+    """
+    The log probability mass function (log PMF) for a Bernoulli distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the log PMF.
+    - `p`: The probability parameter(s).
+
+    # Returns
+    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+    """
+
+    return lax.log(p) * x + (1 - x) * lax.log(1 - p)
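Both functions broadcast elementwise over `x` and `p`; for example:

```python
import jax.numpy as jnp

from bayinx.dists import bernoulli

x = jnp.array([0, 1, 1])
p = jnp.array(0.7)

bernoulli.logprob(x, p)  # [log(0.3), log(0.7), log(0.7)]
```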
bayinx/dists/binomial.py
ADDED
File without changes
bayinx/dists/gamma.py
ADDED
File without changes
bayinx/dists/gamma2.py
ADDED
File without changes
bayinx/dists/normal.py
ADDED
@@ -0,0 +1,83 @@
+# MARK: Imports ----
+import jax.lax as _lax
+
+## Typing
+from jaxtyping import Array, Real
+
+# MARK: Constants
+_PI = 3.141592653589793
+
+
+# MARK: Functions ----
+def prob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The probability density function (PDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the PDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / (
+        sigma * _lax.sqrt(2.0 * _PI)
+    )
+
+
+def logprob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The log of the probability density function (log PDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the log PDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return -_lax.log(sigma * _lax.sqrt(2.0 * _PI)) - 0.5 * _lax.square((x - mu) / sigma)
+
+
+def uprob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The unnormalized probability density function (uPDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the uPDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma
+
+
+def ulogprob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the log uPDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma)
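The normalized and unnormalized variants differ only by the dropped constant; a quick check:

```python
import jax.numpy as jnp

from bayinx.dists import normal

x = jnp.linspace(-2.0, 2.0, 5)

full = normal.logprob(x, jnp.array(0.0), jnp.array(1.0))
unnorm = normal.ulogprob(x, jnp.array(0.0), jnp.array(1.0))

# The two differ by the dropped constant log(sqrt(2 * pi))
assert jnp.allclose(full, unnorm - jnp.log(jnp.sqrt(2.0 * jnp.pi)))
```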
bayinx/machinery/__init__.py
ADDED
File without changes
bayinx/machinery/variational/flows/affine.py
ADDED
@@ -0,0 +1,68 @@
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, Float, Scalar
+
+from bayinx.core import Flow
+
+
+class Affine(Flow):
+    """
+    An affine flow.
+
+    # Attributes
+    - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
+    - `constraints`: A dictionary of constraining transformations.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+        eqx.field(static=True)
+    )
+
+    def __init__(self, dim: int):
+        """
+        Initializes an affine flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.params = {
+            "shift": jnp.zeros(dim),
+            "scale": jnp.zeros((dim, dim)),
+        }
+
+        self.constraints = {"scale": lambda m: jnp.tril(jnp.exp(m))}
+
+    @eqx.filter_jit
+    def forward(self, draws: Array) -> Array:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        shift: Array = params["shift"]
+        scale: Array = params["scale"]
+
+        # Compute forward transformation
+        draws = draws @ scale + shift
+
+        return draws
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0))
+    def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        shift: Array = params["shift"]
+        scale: Array = params["scale"]
+
+        # Compute forward transformation
+        draws = draws @ scale + shift
+
+        # Compute ladj
+        ladj: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+        return ladj, draws
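A quick shape check of the vmapped density adjustment (a sketch; the module path follows the RECORD at the end of this diff):

```python
import jax.random as jr

from bayinx.machinery.variational.flows.affine import Affine

flow = Affine(dim=2)
draws = jr.normal(jr.PRNGKey(0), (5, 2))

# `adjust_density` is vmapped over rows: one log-abs-det per draw
ladj, out = flow.adjust_density(draws)
assert ladj.shape == (5,) and out.shape == (5, 2)
```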
bayinx/machinery/variational/flows/planar.py
ADDED
@@ -0,0 +1,76 @@
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jaxtyping import Array, Float, Scalar
+
+from bayinx.core import Flow
+
+
+class Planar(Flow):
+    """
+    A planar flow.
+
+    # Attributes
+    - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+    - `constraints`: A dictionary of constraining transformations.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+        eqx.field(static=True)
+    )
+
+    def __init__(self, dim: int, key=jr.PRNGKey(0)):
+        """
+        Initializes a planar flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.params = {
+            "u": jr.normal(key, (dim,)),
+            "w": jr.normal(key, (dim,)),
+            "b": jr.normal(key, (1,)),
+        }
+        self.constraints = {}
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0))
+    def forward(self, draws: Array) -> Array:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        w: Array = params["w"]
+        u: Array = params["u"]
+        b: Array = params["b"]
+
+        # Compute forward transformation
+        draws = draws + u * jnp.tanh(draws.dot(w) + b)
+
+        return draws
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0))
+    def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        w: Array = params["w"]
+        u: Array = params["u"]
+        b: Array = params["b"]
+
+        # Compute shared intermediates
+        x: Array = draws.dot(w) + b
+
+        # Compute forward transformation
+        draws = draws + u * jnp.tanh(x)
+
+        # Compute ladj
+        h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
+        ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
+
+        return ladj, draws
bayinx/machinery/variational/flows/radial.py
ADDED
@@ -0,0 +1,95 @@
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jax.numpy.linalg import norm
+from jaxtyping import Array, Float, Scalar
+
+from bayinx.core import Flow
+
+
+class Radial(Flow):
+    """
+    A radial flow.
+
+    # Attributes
+    - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+    - `constraints`: A dictionary of constraining transformations.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+        eqx.field(static=True)
+    )
+
+    def __init__(self, dim: int, key=jr.PRNGKey(0)):
+        """
+        Initializes a radial flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.params = {
+            "alpha": jnp.array(1.0),
+            "beta": jnp.array(1.0),
+            "center": jnp.ones(dim),
+        }
+        self.constraints = {"beta": jnp.exp}
+
+    @partial(jax.vmap, in_axes=(None, 0))
+    @eqx.filter_jit
+    def forward(self, draws: Array) -> Array:
+        """
+        Applies the forward radial transformation for each draw.
+
+        # Parameters
+        - `draws`: Draws from some layer of a normalizing flow.
+
+        # Returns
+        The transformed samples.
+        """
+        params = self.transform_pars()
+
+        # Extract parameters
+        alpha = params["alpha"]
+        beta = params["beta"]
+        center = params["center"]
+
+        # Compute distance to center per-draw
+        r: Array = norm(draws - center)
+
+        # Apply forward transformation
+        draws = draws + (beta / (alpha + r)) * (draws - center)
+
+        return draws
+
+    @partial(jax.vmap, in_axes=(None, 0))
+    @eqx.filter_jit
+    def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+        params = self.transform_pars()
+
+        # Extract parameters
+        alpha = params["alpha"]
+        beta = params["beta"]
+        center = params["center"]
+
+        # Compute distance to center per-draw
+        r: Array = norm(draws - center)
+
+        # Compute shared intermediates
+        x: Array = beta / (alpha + r)
+
+        # Apply forward transformation
+        draws = draws + (x) * (draws - center)
+
+        # Compute density adjustment
+        ladj = jnp.log(
+            jnp.abs(
+                (1.0 + alpha * beta / (alpha + r) ** 2.0) * (1.0 + x) ** (center.size - 1.0)
+            )
+        )
+
+        return ladj, draws
bayinx/machinery/variational/flows/sylvester.py
ADDED
@@ -0,0 +1,76 @@
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jaxtyping import Array, Float, Scalar
+
+from bayinx.core import Flow
+
+
+class Sylvester(Flow):
+    """
+    A Sylvester flow.
+
+    # Attributes
+    - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+    - `constraints`: A dictionary of constraining transformations.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+        eqx.field(static=True)
+    )
+
+    def __init__(self, dim: int, key=jr.PRNGKey(0)):
+        """
+        Initializes a Sylvester flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.params = {
+            "u": jr.normal(key, (dim,)),
+            "w": jr.normal(key, (dim,)),
+            "b": jr.normal(key, (1,)),
+        }
+        self.constraints = {}
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0))
+    def forward(self, draws: Array) -> Array:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        w: Array = params["w"]
+        u: Array = params["u"]
+        b: Array = params["b"]
+
+        # Compute forward transformation
+        draws = draws + u * jnp.tanh(draws.dot(w) + b)
+
+        return draws
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0))
+    def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        w: Array = params["w"]
+        u: Array = params["u"]
+        b: Array = params["b"]
+
+        # Compute shared intermediates
+        x: Array = draws.dot(w) + b
+
+        # Compute forward transformation
+        draws = draws + u * jnp.tanh(x)
+
+        # Compute ladj
+        h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
+        ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
+
+        return ladj, draws
bayinx/machinery/variational/meanfield.py
ADDED
@@ -0,0 +1,124 @@
+from typing import Any, Callable, Dict, Self
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.random as jr
+import jax.tree_util as jtu
+from jax.flatten_util import ravel_pytree
+from jaxtyping import Array, Float, Key, Scalar
+
+from bayinx.core import Model, Variational
+from bayinx.dists import normal
+
+
+class MeanField(Variational):
+    """
+    A fully factorized Gaussian approximation to a posterior distribution.
+
+    # Attributes
+    - `var_params`: The variational parameters for the approximation.
+    """
+
+    var_params: Dict[str, Float[Array, "..."]]
+    _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
+    _constraints: Model = eqx.field(static=True)
+
+    def __init__(self, model: Model):
+        """
+        Constructs an unoptimized meanfield posterior approximation.
+
+        # Parameters
+        - `model`: A probabilistic `Model` object.
+        """
+        # Partition model
+        params, self._constraints = eqx.partition(model, eqx.is_array)
+
+        # Flatten params component
+        flat_params, self._unflatten = ravel_pytree(params)
+
+        # Initialize variational parameters
+        self.var_params = {
+            "mean": flat_params,
+            "log_std": jnp.zeros(flat_params.size, dtype=flat_params.dtype),
+        }
+
+    @eqx.filter_jit
+    def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
+        # Sample variational draws
+        draws: Array = (
+            jr.normal(key=key, shape=(n, self.var_params["mean"].size))
+            * jnp.exp(self.var_params["log_std"])
+            + self.var_params["mean"]
+        )
+
+        return draws
+
+    @eqx.filter_jit
+    def eval(self, draws: Array) -> Array:
+        return normal.logprob(
+            x=draws,
+            mu=self.var_params["mean"],
+            sigma=jnp.exp(self.var_params["log_std"]),
+        ).sum(axis=1)
+
+    @eqx.filter_jit
+    def filter_spec(self):
+        filter_spec = jtu.tree_map(lambda _: False, self)
+        filter_spec = eqx.tree_at(
+            lambda mf: mf.var_params,
+            filter_spec,
+            replace=True,
+        )
+        return filter_spec
+
+    @eqx.filter_jit
+    def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
+        """
+        Estimate the ELBO at the current variational parameters.
+        """
+        # Partition variational
+        dyn, static = eqx.partition(self, self.filter_spec())
+
+        @eqx.filter_jit
+        def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
+            # Combine
+            vari = eqx.combine(dyn, static)
+
+            # Sample draws from variational distribution
+            draws: Array = vari.sample(n, key)
+
+            # Evaluate posterior density for each draw
+            posterior_evals: Array = vari.eval_model(draws, data)
+
+            # Evaluate variational density for each draw
+            variational_evals: Array = vari.eval(draws)
+
+            # Evaluate ELBO
+            return jnp.mean(posterior_evals - variational_evals)
+
+        return elbo(dyn, n, key, data)
+
+    @eqx.filter_jit
+    def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
+        # Partition
+        dyn, static = eqx.partition(self, self.filter_spec())
+
+        @eqx.filter_grad
+        @eqx.filter_jit
+        def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
+            # Combine
+            vari = eqx.combine(dyn, static)
+
+            # Sample draws from variational distribution
+            draws: Array = vari.sample(n, key)
+
+            # Evaluate posterior density for each draw
+            posterior_evals: Array = vari.eval_model(draws, data)
+
+            # Evaluate variational density for each draw
+            variational_evals: Array = vari.eval(draws)
+
+            # Evaluate ELBO
+            return jnp.mean(posterior_evals - variational_evals)
+
+        return elbo_grad(dyn, n, key, data)
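Putting the pieces together, a hedged end-to-end sketch reusing the hypothetical `NormalMean` model from earlier (`fit` is supplied by the `Variational` superclass above):

```python
import jax.numpy as jnp
import jax.random as jr

from bayinx.machinery.variational.meanfield import MeanField

data = jnp.array([0.3, -0.1, 0.8])  # toy observations (hypothetical)
vari = MeanField(NormalMean())

# Optimize the ELBO, then draw from the fitted approximation
vari = vari.fit(max_iters=1000, data=data, key=jr.PRNGKey(1))
draws = vari.sample(100, jr.PRNGKey(2))
```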
bayinx/machinery/variational/normalizing_flow.py
ADDED
@@ -0,0 +1,152 @@
+from typing import Any, Callable, Self, Tuple
+
+import equinox as eqx
+import jax.flatten_util as jfu
+import jax.numpy as jnp
+import jax.random as jr
+import jax.tree_util as jtu
+from jaxtyping import Array, Float, Key, Scalar
+
+from bayinx.core import Flow, Model, Variational
+
+
+class NormalizingFlow(Variational):
+    """
+    An ordered collection of diffeomorphisms that map a base distribution to a
+    normalized approximation of a posterior distribution.
+
+    # Attributes
+    - `base`: A base variational distribution.
+    - `flows`: An ordered collection of continuously parameterized
+      diffeomorphisms.
+    """
+
+    flows: list[Flow]
+    base: Variational
+    _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
+    _constraints: Model = eqx.field(static=True)
+
+    def __init__(self, base: Variational, flows: list[Flow], model: Model):
+        """
+        Constructs an unoptimized normalizing flow posterior approximation.
+
+        # Parameters
+        - `base`: The base variational distribution.
+        - `flows`: A list of diffeomorphisms.
+        - `model`: A probabilistic `Model` object.
+        """
+        # Partition model
+        params, self._constraints = eqx.partition(model, eqx.is_array)
+
+        # Flatten params component
+        flat_params, self._unflatten = jfu.ravel_pytree(params)
+
+        self.base = base
+        self.flows = flows
+
+    @eqx.filter_jit
+    def sample(self, n: int, key: Key = jr.PRNGKey(0)):
+        """
+        Sample from the variational distribution `n` times.
+        """
+        # Sample from the base distribution
+        draws: Array = self.base.sample(n, key)
+
+        # Apply forward transformations
+        for map in self.flows:
+            draws = map.forward(draws)
+
+        return draws
+
+    @eqx.filter_jit
+    def eval(self, draws: Array) -> Array:
+        # Evaluate base density
+        variational_evals: Array = self.base.eval(draws)
+
+        for map in self.flows:
+            # Compute adjustment
+            ladj, draws = map.adjust_density(draws)
+
+            # Adjust variational density
+            variational_evals = variational_evals - ladj
+
+        return variational_evals
+
+    @eqx.filter_jit
+    def _eval(self, draws: Array, data=None) -> Tuple[Scalar, Array]:
+        """
+        Evaluate the posterior and variational densities at the transformed
+        `draws` to avoid extra compute when requiring variational draws for
+        the posterior evaluation.
+
+        # Parameters
+        - `draws`: Draws from the base variational distribution.
+        - `data`: Any data required to evaluate the posterior density.
+
+        # Returns
+        The posterior and variational densities.
+        """
+        # Evaluate base density
+        variational_evals: Array = self.base.eval(draws)
+
+        for map in self.flows:
+            # Compute adjustment
+            ladj, draws = map.adjust_density(draws)
+
+            # Adjust variational density
+            variational_evals = variational_evals - ladj
+
+        # Evaluate posterior at final variational draws
+        posterior_evals = self.eval_model(draws, data)
+
+        return posterior_evals, variational_evals
+
+    def filter_spec(self):
+        # Only optimize the parameters of the flows
+        filter_spec = jtu.tree_map(lambda _: False, self)
+        filter_spec = eqx.tree_at(
+            lambda nf: nf.flows,
+            filter_spec,
+            replace=True,
+        )
+
+        return filter_spec
+
+    @eqx.filter_jit
+    def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
+        # Partition
+        dyn, static = eqx.partition(self, self.filter_spec())
+
+        @eqx.filter_jit
+        def elbo(dyn: Self, n: int, key: Key, data: Any = None):
+            # Combine
+            self = eqx.combine(dyn, static)
+
+            # Sample draws from variational distribution
+            draws: Array = self.base.sample(n, key)
+
+            posterior_evals, variational_evals = self._eval(draws, data)
+            # Evaluate ELBO
+            return jnp.mean(posterior_evals - variational_evals)
+
+        return elbo(dyn, n, key, data)
+
+    @eqx.filter_jit
+    def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
+        # Partition
+        dyn, static = eqx.partition(self, self.filter_spec())
+
+        @eqx.filter_grad
+        @eqx.filter_jit
+        def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
+            # Combine
+            self = eqx.combine(dyn, static)
+
+            # Sample draws from variational distribution
+            draws: Array = self.base.sample(n, key)
+
+            posterior_evals, variational_evals = self._eval(draws, data)
+            # Evaluate ELBO
+            return jnp.mean(posterior_evals - variational_evals)
+
+        return elbo_grad(dyn, n, key, data)
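A composition sketch, assuming the module paths in the RECORD below: a `Standard` base (defined next) pushed through two planar layers. Note the flow dimension must match the model's flattened parameter count (1 for the hypothetical `NormalMean`):

```python
import jax.numpy as jnp
import jax.random as jr

from bayinx.machinery.variational.flows.planar import Planar
from bayinx.machinery.variational.normalizing_flow import NormalizingFlow
from bayinx.machinery.variational.standard import Standard

model = NormalMean()  # hypothetical model from the earlier sketch
flows = [Planar(dim=1, key=jr.PRNGKey(k)) for k in range(2)]

# Standard normal base measure transformed by the planar layers
vari = NormalizingFlow(Standard(model), flows, model)
vari = vari.fit(max_iters=500, data=jnp.array([0.3, -0.1, 0.8]))
```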
bayinx/machinery/variational/standard.py
ADDED
@@ -0,0 +1,67 @@
+from typing import Callable
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.random as jr
+import jax.tree_util as jtu
+from jax.flatten_util import ravel_pytree
+from jaxtyping import Array, Float, Key
+
+from bayinx.core import Model, Variational
+from bayinx.dists import normal
+
+
+class Standard(Variational):
+    """
+    A standard normal distribution approximation to a posterior distribution.
+
+    # Attributes
+    - `dim`: Dimension of the parameter space.
+    """
+
+    dim: int = eqx.field(static=True)
+    _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
+    _constraints: Model = eqx.field(static=True)
+
+    def __init__(self, model: Model):
+        """
+        Constructs a standard normal approximation to a posterior distribution.
+
+        # Parameters
+        - `model`: A probabilistic `Model` object.
+        """
+        # Partition model
+        _, self._constraints = eqx.partition(model, eqx.is_array)
+
+        # Flatten params component
+        _, self._unflatten = ravel_pytree(_)
+
+        # Store dimension of parameter space
+        self.dim = jnp.size(_)
+
+    @eqx.filter_jit
+    def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
+        # Sample variational draws
+        draws: Array = jr.normal(key=key, shape=(n, self.dim))
+
+        return draws
+
+    @eqx.filter_jit
+    def eval(self, draws: Array) -> Array:
+        return normal.logprob(
+            x=draws,
+            mu=jnp.array(0.0),
+            sigma=jnp.array(1.0),
+        ).sum(axis=1)
+
+    @eqx.filter_jit
+    def filter_spec(self):
+        filter_spec = jtu.tree_map(lambda _: False, self)
+
+        return filter_spec
+
+    def elbo(self):
+        return None
+
+    def elbo_grad(self):
+        return None
bayinx-0.2.3.dist-info/METADATA
ADDED
@@ -0,0 +1,40 @@
+Metadata-Version: 2.4
+Name: bayinx
+Version: 0.2.3
+Summary: Bayesian Inference with JAX
+Requires-Python: >=3.12
+Requires-Dist: equinox>=0.11.12
+Requires-Dist: jax>=0.4.38
+Requires-Dist: jaxtyping>=0.2.36
+Requires-Dist: optax>=0.2.4
+Requires-Dist: pytest-benchmark>=5.1.0
+Requires-Dist: pytest>=8.3.5
+Description-Content-Type: text/markdown
+
+# <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+
+The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of methods for estimation (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
+
+In the short term, I'm going to focus on:
+1) Implementing as much machinery as I feel is enough.
+2) Figuring out how to design the `Model` superclass to have something like the `transformed pars {}` block, but one that unifies transformations and constraints.
+3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
+
+In the long term, I'm going to focus on:
+1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (same as `Stan` does).
+2) How to make working with the posterior as easy as possible.
+   - That's a vague goal, but practically it means figuring out how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
+
+Although this is somewhat separate from the goals of the project, if this does pan out how I'm envisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
+
+Additionally, when I get around to it I'd like the package documentation to also include theoretical and implementation details for all machinery implemented (with overthinking boxes, because I do like that design from McElreath's book).
+
+
+# TODO
+- Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
+- Look into adaptively tuning ADAM hyperparameters.
+- Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
+- Low-rank affine flow?
+- Implement Sylvester flows: https://arxiv.org/pdf/1803.05649
+- Learn how to generate documentation lol.
+- Figure out how to make `transform_pars` for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
bayinx-0.2.3.dist-info/RECORD
ADDED
@@ -0,0 +1,26 @@
+bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
+bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
+bayinx/core/flow.py,sha256=4vj1t2xNPGp1VPE4xUshY-rHAw__KvSwjGDtKkW2taE,2252
+bayinx/core/model.py,sha256=AI4eHrXAds3K7eWgZ9g5E6Kh76HP6WTn6s6q_0tnhck,1719
+bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
+bayinx/core/variational.py,sha256=T42uUNkF2tP1HJPyeIv7ISdika_G28wR_OOFXzx_hgo,4978
+bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
+bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
+bayinx/machinery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/machinery/variational/__init__.py,sha256=5GdhqBOHKsXg2tZGAMNlxyrLPD0-s64wAEy8998cHZ4,247
+bayinx/machinery/variational/meanfield.py,sha256=stI96DNhHNIROnr8rLNEN9SN_0lXqkwit0KNto34Q6A,3889
+bayinx/machinery/variational/normalizing_flow.py,sha256=qypgPq9vIqSIJNOHCDaN-hvwFfttNQ5_yXqvmi5hslI,4796
+bayinx/machinery/variational/standard.py,sha256=IQdNd5QIE8u3zcOw7K4EW69lIQ0ZUGGDvwZVyvrYHxA,1739
+bayinx/machinery/variational/flows/__init__.py,sha256=VGh-ffuUfMso_0JxwGCJQ2yVnFJdOrkFsSnorojQldY,213
+bayinx/machinery/variational/flows/affine.py,sha256=TPyUUPRoSkyDMwGO5wtq-Ei8DAUvlb_N6JCk7uPlbJQ,1748
+bayinx/machinery/variational/flows/planar.py,sha256=rJ1XpqoWzig_5Udq6oCh5JV4ptlTRLRS7tb9DCX22lE,2013
+bayinx/machinery/variational/flows/radial.py,sha256=NF1tCd_PH6m8eqjJkom2c30sRUQ04Vf8zeRx_RQCDcg,2526
+bayinx/machinery/variational/flows/sylvester.py,sha256=DeZl4Fkz9XpCGsfDcjS0eWlrMR0xMO0MPfJvoDhixSA,2019
+bayinx-0.2.3.dist-info/METADATA,sha256=3sJiQZ3firSKOjm5DFeet0euZ6ElfRO56x2AoP8ktNk,3099
+bayinx-0.2.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+bayinx-0.2.3.dist-info/RECORD,,
bayinx-0.1.0.dist-info/METADATA
DELETED
bayinx-0.1.0.dist-info/RECORD
DELETED
@@ -1,5 +0,0 @@
-bayinx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bayinx-0.1.0.dist-info/METADATA,sha256=vo6K-1pFW6eW-SeCGyfJSxjkJ92h21qxUgLeVCL8C3Q,209
-bayinx-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-bayinx-0.1.0.dist-info/RECORD,,
{bayinx-0.1.0.dist-info → bayinx-0.2.3.dist-info}/WHEEL
RENAMED
File without changes