bayinx 0.2.23__tar.gz → 0.2.25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayinx-0.2.23 → bayinx-0.2.25}/PKG-INFO +1 -1
- {bayinx-0.2.23 → bayinx-0.2.25}/pyproject.toml +2 -2
- bayinx-0.2.25/src/bayinx/core/constraints.py +61 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/core/flow.py +1 -3
- bayinx-0.2.25/src/bayinx/core/model.py +74 -0
- bayinx-0.2.25/src/bayinx/core/variational.py +162 -0
- bayinx-0.2.25/src/bayinx/mhx/__init__.py +1 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/normalizing_flow.py +1 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/tests/test_variational.py +17 -15
- bayinx-0.2.23/src/bayinx/core/model.py +0 -74
- bayinx-0.2.23/src/bayinx/core/utils.py +0 -54
- bayinx-0.2.23/src/bayinx/core/variational.py +0 -167
- {bayinx-0.2.23 → bayinx-0.2.25}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/.gitignore +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/README.md +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/__init__.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/core/__init__.py +0 -0
- /bayinx-0.2.23/src/bayinx/mhx/__init__.py → /bayinx-0.2.25/src/bayinx/core/utils.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/__init__.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/binomial.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/gamma.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/dists/uniform.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/flows/planar.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/flows/radial.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/meanfield.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/mhx/vi/standard.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/py.typed +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/tests/__init__.py +0 -0
- {bayinx-0.2.23 → bayinx-0.2.25}/uv.lock +0 -0
{bayinx-0.2.23 → bayinx-0.2.25}/pyproject.toml:

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.2.23"
+version = "0.2.25"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
 
 [tool.bumpversion]
-current_version = "0.2.23"
+current_version = "0.2.25"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"
```
bayinx-0.2.25/src/bayinx/core/constraints.py (new file):

```diff
@@ -0,0 +1,61 @@
+from abc import abstractmethod
+from typing import Tuple
+
+import equinox as eqx
+import jax.numpy as jnp
+from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
+
+
+class Constraint(eqx.Module):
+    """
+    Abstract base class for defining parameter constraints.
+
+    Subclasses should implement the `constrain` method to apply the
+    transformation and compute the ladj adjustment.
+    """
+    @abstractmethod
+    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+        """
+        Applies the constraining transformation to an unconstrained input
+        and computes the log absolute determinant of the Jacobian (ladj)
+        of this transformation.
+
+        # Parameters
+        - `x`: The unconstrained JAX Array-like input.
+
+        # Returns
+        A tuple containing:
+        - The constrained JAX Array.
+        - A scalar JAX Array representing the ladj of the transformation.
+        """
+        pass
+
+
+class LowerBound(Constraint):
+    """
+    Enforces a lower bound on the parameter.
+    """
+    lb: ScalarLike
+
+    def __init__(self, lb: ScalarLike):
+        self.lb = lb
+
+    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+        """
+        Applies the lower bound constraint and computes the ladj.
+
+        # Parameters
+        - `x`: The unconstrained JAX Array-like input.
+
+        # Returns
+        A tuple containing:
+        - The constrained JAX Array (x > self.lb).
+        - A scalar JAX Array representing the ladj of the transformation.
+        """
+        # Compute transformation adjustment
+        ladj = jnp.sum(x)
+
+        # Compute transformation
+        x = jnp.exp(x) + self.lb
+
+        return x, ladj
```
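Since the transformation is y = exp(x) + lb with Jacobian diag(exp(x)), the log absolute determinant is sum(x), which is exactly the `ladj` computed above. A minimal usage sketch (the import path follows the new file layout; the values are made up):

```python
import jax.numpy as jnp

from bayinx.core.constraints import LowerBound

# Unconstrained draws for a positive parameter, e.g. a scale.
x = jnp.array([-1.0, 0.5])

sigma, ladj = LowerBound(0.0).constrain(x)
# sigma = exp(x) + 0.0  -> elementwise positive
# ladj  = x.sum()       -> log|det J| of the transformation
```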
{bayinx-0.2.23 → bayinx-0.2.25}/src/bayinx/core/flow.py:

```diff
@@ -5,10 +5,8 @@ import equinox as eqx
 import jax.tree_util as jtu
 from jaxtyping import Array, Float
 
-from bayinx.core.utils import __MyMeta
 
-
-class Flow(eqx.Module, metaclass=__MyMeta):
+class Flow(eqx.Module):
     """
     A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
 
```
bayinx-0.2.25/src/bayinx/core/model.py (new file):

```diff
@@ -0,0 +1,74 @@
+from abc import abstractmethod
+from typing import Any, Dict, Tuple
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.tree_util as jtu
+from jaxtyping import Array, Scalar
+
+from bayinx.core.constraints import Constraint
+
+
+class Model(eqx.Module):
+    """
+    A superclass used to define probabilistic models.
+
+    # Attributes
+    - `params`: A dictionary of JAX Arrays representing parameters of the model.
+    - `constraints`: A dictionary of `Constraint` objects that constrain their corresponding parameter.
+    """
+
+    params: Dict[str, Array]
+    constraints: Dict[str, Constraint]
+
+    @abstractmethod
+    def eval(self, data: Any) -> Scalar:
+        pass
+
+    # Default filter specification
+    def filter_spec(self):
+        """
+        Generates a filter specification to subset relevant parameters for the model.
+        """
+        # Generate empty specification
+        filter_spec = jtu.tree_map(lambda _: False, self)
+
+        # Specify JAX Array parameters
+        filter_spec = eqx.tree_at(
+            lambda model: model.params,
+            filter_spec,
+            replace=jtu.tree_map(eqx.is_array, self.params),
+        )
+
+        return filter_spec
+
+    # Add constrain method
+    @eqx.filter_jit
+    def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+        """
+        Constrain `params` to the appropriate domain.
+
+        # Returns
+        A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
+        """
+        t_params: Dict[str, Array] = self.params
+        target: Scalar = jnp.array(0.0)
+
+        for par, map in self.constraints.items():
+            # Apply transformation
+            t_params[par], ladj = map.constrain(t_params[par])
+
+            # Adjust posterior density
+            target -= ladj
+
+        return t_params, target
+
+    def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+        """
+        Apply a custom transformation to `params` if needed.
+
+        # Returns
+        A dictionary of transformed JAX Arrays representing the transformed parameters.
+        """
+        return self.constrain_pars()
```
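The practical consequence for downstream code is that `constrain_pars` now returns the density adjustment alongside the parameters, rather than being bolted on by a metaclass. A hypothetical subclass under the new API (class name, shapes, and data are invented for illustration; the `normal.logprob` keyword signature is the one used in the test diff below):

```python
import jax.numpy as jnp

from bayinx.core import Model
from bayinx.core.constraints import LowerBound
from bayinx.dists import normal


class NormalModel(Model):
    def __init__(self):
        self.params = {"mu": jnp.zeros(2), "sigma": jnp.zeros(2)}
        self.constraints = {"sigma": LowerBound(0.0)}

    def eval(self, data):
        # The ladj adjustment returned by constrain_pars seeds the target.
        params, target = self.constrain_pars()
        target += normal.logprob(
            x=data, mu=params["mu"], sigma=params["sigma"]
        ).sum()
        return target
```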
bayinx-0.2.25/src/bayinx/core/variational.py (new file):

```diff
@@ -0,0 +1,162 @@
+from abc import abstractmethod
+from functools import partial
+from typing import Any, Callable, Self, Tuple
+
+import equinox as eqx
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.random as jr
+import optax as opx
+from jaxtyping import Array, Float, Key, PyTree, Scalar
+from optax import GradientTransformation, OptState, Schedule
+
+from bayinx.core import Model
+
+
+class Variational(eqx.Module):
+    """
+    A superclass used to define variational methods.
+
+    # Attributes
+    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
+    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
+    """
+
+    _unflatten: Callable[[Float[Array, "..."]], Model]
+    _constraints: Model
+
+    @abstractmethod
+    def sample(self, n: int, key: Key) -> Array:
+        """
+        Sample from the variational distribution.
+        """
+        pass
+
+    @abstractmethod
+    def eval(self, draws: Array) -> Array:
+        """
+        Evaluate the variational distribution at `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
+        """
+        Evaluate the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
+        """
+        Evaluate the gradient of the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def filter_spec(self):
+        """
+        Filter specification for dynamic and static components of the `Variational`.
+        """
+        pass
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0, None))
+    def eval_model(self, draws: Array, data: Any = None) -> Array:
+        """
+        Reconstruct models from variational draws and evaluate their posterior density.
+
+        # Parameters
+        - `draws`: A set of variational draws.
+        - `data`: Data used to evaluate the posterior (if needed).
+        """
+        # Unflatten variational draw
+        model: Model = self._unflatten(draws)
+
+        # Combine with constraints
+        model: Model = eqx.combine(model, self._constraints)
+
+        # Evaluate posterior density
+        return model.eval(data)
+
+    @eqx.filter_jit
+    def fit(
+        self,
+        max_iters: int,
+        data: Any = None,
+        learning_rate: float = 1,
+        weight_decay: float = 1e-4,
+        tolerance: float = 1e-4,
+        var_draws: int = 1,
+        key: Key = jr.PRNGKey(0),
+    ) -> Self:
+        """
+        Optimize the variational distribution.
+
+        # Parameters
+        - `max_iters`: Maximum number of iterations for the optimization loop.
+        - `data`: Data to evaluate the posterior density with (if available).
+        - `learning_rate`: Initial learning rate for optimization.
+        - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
+        - `var_draws`: Number of variational draws to draw each iteration.
+        - `key`: A PRNG key.
+        """
+        # Partition variational
+        dyn, static = eqx.partition(self, self.filter_spec())
+
+        # Construct scheduler
+        schedule: Schedule = opx.cosine_decay_schedule(
+            init_value=learning_rate, decay_steps=max_iters
+        )
+
+        # Initialize optimizer
+        optim: GradientTransformation = opx.chain(
+            opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+        )
+        opt_state: OptState = optim.init(dyn)
+
+        # Optimization loop helper functions
+        @eqx.filter_jit
+        def condition(state: Tuple[Self, OptState, Scalar, Key]):
+            # Unpack iteration state
+            dyn, opt_state, i, key = state
+
+            return i < max_iters
+
+        @eqx.filter_jit
+        def body(state: Tuple[Self, OptState, Scalar, Key]):
+            # Unpack iteration state
+            dyn, opt_state, i, key = state
+
+            # Update iteration
+            i = i + 1
+
+            # Update PRNG key
+            key, _ = jr.split(key)
+
+            # Combine variational
+            vari = eqx.combine(dyn, static)
+
+            # Compute gradient of the ELBO
+            updates: PyTree = vari.elbo_grad(var_draws, key, data)
+
+            # Compute updates
+            updates, opt_state = optim.update(
+                updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+            )
+
+            # Update variational distribution
+            dyn = eqx.apply_updates(dyn, updates)
+
+            return dyn, opt_state, i, key
+
+        # Run optimization loop
+        dyn = lax.while_loop(
+            cond_fun=condition,
+            body_fun=body,
+            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
+        )[0]
+
+        # Return optimized variational
+        return eqx.combine(dyn, static)
```
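The optimizer setup inside `fit` is worth calling out: optax minimizes, so the `opx.scale(-1.0)` stage flips the ELBO gradient into a descent direction before NAdamW with a cosine-decayed learning rate is applied. A standalone sketch of the same chain (constants are illustrative):

```python
import optax as opx

max_iters = 10_000

# Cosine-decayed learning rate over the full optimization run.
schedule = opx.cosine_decay_schedule(init_value=1.0, decay_steps=max_iters)

# Negate gradients (ELBO ascent as minimization), then apply NAdamW.
optim = opx.chain(opx.scale(-1.0), opx.nadamw(schedule, weight_decay=1e-4))
```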
bayinx-0.2.25/src/bayinx/mhx/__init__.py (new file):

```diff
@@ -0,0 +1 @@
+
```
{bayinx-0.2.23 → bayinx-0.2.25}/tests/test_variational.py:

```diff
@@ -26,12 +26,13 @@ def test_meanfield(benchmark, var_draws):
         @eqx.filter_jit
         def eval(self, data: dict):
             # Get constrained parameters
-            params = self.constrain_pars()
+            params, target = self.constrain_pars()
 
             # Evaluate mu ~ N(10,1)
-            return normal.logprob(
-                x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
-            )
+            target += normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)).sum()
+
+            # Return target
+            return target
 
     # Construct model
     model = NormalDist()
@@ -44,7 +45,7 @@ def test_meanfield(benchmark, var_draws):
         vari.fit(10000, var_draws=var_draws)
 
     benchmark(benchmark_fit)
-    vari = vari.fit(20000)
+    vari = vari.fit(20000, var_draws=var_draws)
 
     # Assert parameters are roughly correct
     assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
@@ -66,12 +67,13 @@ def test_affine(benchmark, var_draws):
         @eqx.filter_jit
         def eval(self, data: dict):
             # Get constrained parameters
-            params = self.constrain_pars()
+            params, target = self.constrain_pars()
 
             # Evaluate mu ~ N(10,1)
-            return normal.logprob(
-                x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
-            )
+            target += normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)).sum()
+
+            # Return target
+            return target
 
     # Construct model
     model = NormalDist()
@@ -84,7 +86,7 @@ def test_affine(benchmark, var_draws):
         vari.fit(10000, var_draws=var_draws)
 
     benchmark(benchmark_fit)
-    vari = vari.fit(20000)
+    vari = vari.fit(20000, var_draws=var_draws)
 
     params = vari.flows[0].constrain_pars()
     assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
@@ -106,12 +108,12 @@ def test_flows(benchmark, var_draws):
         @eqx.filter_jit
         def eval(self, data: dict):
             # Get constrained parameters
-            params = self.constrain_pars()
+            params, target = self.constrain_pars()
 
             # Evaluate mu ~ N(10,1)
-            return normal.logprob(
-                x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
-            )
+            target += normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)).sum()
+
+            return target
 
     # Construct model
     model = NormalDist()
@@ -126,7 +128,7 @@ def test_flows(benchmark, var_draws):
         vari.fit(10000, var_draws=var_draws)
 
     benchmark(benchmark_fit)
-    vari = vari.fit(20000)
+    vari = vari.fit(20000, var_draws=var_draws)
 
     mean = vari.sample(1000).mean(0)
     var = vari.sample(1000).var(0)
```
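For any existing `eval` implementation, the test updates reduce to a one-line migration:

```python
# 0.2.23: constrain_pars returned only the transformed parameters
params = self.constrain_pars()

# 0.2.25: it also returns the density adjustment,
# which should seed the returned target
params, target = self.constrain_pars()
```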
bayinx-0.2.23/src/bayinx/core/model.py (removed):

```diff
@@ -1,74 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Callable, Dict
-
-import equinox as eqx
-import jax.tree_util as jtu
-from jaxtyping import Array, Scalar
-
-from bayinx.core.utils import __MyMeta
-
-
-class Model(eqx.Module, metaclass=__MyMeta):
-    """
-    A superclass used to define probabilistic models.
-
-    # Attributes
-    - `params`: A dictionary of JAX Arrays representing parameters of the model.
-    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
-    """
-
-    params: Dict[str, Array]
-    constraints: Dict[str, Callable[[Array], Array]]
-
-    @abstractmethod
-    def eval(self, data: Any) -> Scalar:
-        pass
-
-    # Default filter specification
-    def filter_spec(self):
-        """
-        Generates a filter specification to subset relevant parameters for the model.
-        """
-        # Generate empty specification
-        filter_spec = jtu.tree_map(lambda _: False, self)
-
-        # Specify JAX Array parameters
-        filter_spec = eqx.tree_at(
-            lambda model: model.params,
-            filter_spec,
-            replace=jtu.tree_map(eqx.is_array, self.params),
-        )
-
-        return filter_spec
-
-    def __init_subclass__(cls):
-        # Add constrain method
-        def constrain_pars(self: Model) -> Dict[str, Array]:
-            """
-            Constrain `params` to the appropriate domain.
-
-            # Returns
-            A dictionary of transformed JAX Arrays representing the constrained parameters.
-            """
-            t_params = self.params
-
-            for par, map in self.constraints.items():
-                t_params[par] = map(t_params[par])
-
-            return t_params
-
-        cls.constrain_pars = eqx.filter_jit(constrain_pars)
-
-        # Add transform_pars method if not present
-        if not callable(getattr(cls, "transform_pars", None)):
-
-            def transform_pars(self: Model) -> Dict[str, Array]:
-                """
-                Apply a custom transformation to `params` if needed.
-
-                # Returns
-                A dictionary of transformed JAX Arrays representing the transformed parameters.
-                """
-                return self.constrain_pars()
-
-            cls.transform_pars = eqx.filter_jit(transform_pars)
```
bayinx-0.2.23/src/bayinx/core/utils.py (removed):

```diff
@@ -1,54 +0,0 @@
-from typing import Callable, Dict
-
-import equinox as eqx
-from jaxtyping import Array
-
-
-class __MyMeta(type(eqx.Module)):
-    """
-    Metaclass to ensure attribute types are respected.
-    """
-
-    def __call__(cls, *args, **kwargs):
-        obj = super().__call__(*args, **kwargs)
-
-        # Check parameters are a Dict of JAX Arrays
-        if not isinstance(obj.params, Dict):
-            raise ValueError(
-                f"Model {cls.__name__} must initialize 'params' as a dictionary."
-            )
-
-        for key, value in obj.params.items():
-            if not isinstance(value, Array):
-                raise TypeError(f"Parameter '{key}' must be a JAX Array.")
-
-        # Check constraints are a Dict of functions
-        if not isinstance(obj.constraints, Dict):
-            raise ValueError(
-                f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
-            )
-
-        for key, value in obj.constraints.items():
-            if not isinstance(value, Callable):
-                raise TypeError(f"Constraint for parameter '{key}' must be a function.")
-
-        # Check that the constrain method returns a dict equivalent to `params`
-        t_params: Dict[str, Array] = obj.constrain_pars()
-
-        if not isinstance(t_params, Dict):
-            raise ValueError(
-                f"The 'constrain' method of {cls.__name__} must return a Dict of JAX Arrays."
-            )
-
-        for key, value in t_params.items():
-            if not isinstance(value, Array):
-                raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
-
-            if not value.shape == obj.params[key].shape:
-                raise ValueError(
-                    f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
-                )
-
-        # Check transform_pars
-
-        return obj
```
bayinx-0.2.23/src/bayinx/core/variational.py (removed):

```diff
@@ -1,167 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Callable, Self, Tuple
-
-import equinox as eqx
-import jax
-import jax.lax as lax
-import jax.numpy as jnp
-import jax.random as jr
-import optax as opx
-from jaxtyping import Array, Float, Key, PyTree, Scalar
-from optax import GradientTransformation, OptState, Schedule
-
-from bayinx.core import Model
-
-
-class Variational(eqx.Module):
-    """
-    A superclass used to define variational methods.
-
-    # Attributes
-    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
-    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
-    """
-
-    _unflatten: Callable[[Float[Array, "..."]], Model]
-    _constraints: Model
-
-    @abstractmethod
-    def sample(self, n: int, key: Key) -> Array:
-        """
-        Sample from the variational distribution.
-        """
-        pass
-
-    @abstractmethod
-    def eval(self, draws: Array) -> Array:
-        """
-        Evaluate the variational distribution at `draws`.
-        """
-        pass
-
-    @abstractmethod
-    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
-        """
-        Evaluate the ELBO.
-        """
-        pass
-
-    @abstractmethod
-    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
-        """
-        Evaluate the gradient of the ELBO.
-        """
-        pass
-
-    @abstractmethod
-    def filter_spec(self):
-        """
-        Filter specification for dynamic and static components of the `Variational`.
-        """
-        pass
-
-    def __init_subclass__(cls):
-        """
-        Construct methods that are shared across all VI methods.
-        """
-
-        def eval_model(self, draws: Array, data: Any = None) -> Array:
-            """
-            Reconstruct models from variational draws and evaluate their posterior density.
-
-            # Parameters
-            - `draws`: A set of variational draws.
-            - `data`: Data used to evaluate the posterior(if needed).
-            """
-            # Unflatten variational draw
-            model: Model = self._unflatten(draws)
-
-            # Combine with constraints
-            model: Model = eqx.combine(model, self._constraints)
-
-            # Evaluate posterior density
-            return model.eval(data)
-
-        cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))
-
-        def fit(
-            self,
-            max_iters: int,
-            data: Any = None,
-            learning_rate: float = 1,
-            weight_decay: float = 1e-4,
-            tolerance: float = 1e-4,
-            var_draws: int = 1,
-            key: Key = jr.PRNGKey(0),
-        ) -> Self:
-            """
-            Optimize the variational distribution.
-
-            # Parameters
-            - `max_iters`: Maximum number of iterations for the optimization loop.
-            - `data`: Data to evaluate the posterior density with(if available).
-            - `learning_rate`: Initial learning rate for optimization.
-            - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
-            - `var_draws`: Number of variational draws to draw each iteration.
-            - `key`: A PRNG key.
-            """
-            # Partition variational
-            dyn, static = eqx.partition(self, self.filter_spec())
-
-            # Construct scheduler
-            schedule: Schedule = opx.cosine_decay_schedule(
-                init_value=learning_rate, decay_steps=max_iters
-            )
-
-            # Initialize optimizer
-            optim: GradientTransformation = opx.chain(
-                opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
-            )
-            opt_state: OptState = optim.init(dyn)
-
-            # Optimization loop helper functions
-            @eqx.filter_jit
-            def condition(state: Tuple[Self, OptState, Scalar, Key]):
-                # Unpack iteration state
-                dyn, opt_state, i, key = state
-
-                return i < max_iters
-
-            @eqx.filter_jit
-            def body(state: Tuple[Self, OptState, Scalar, Key]):
-                # Unpack iteration state
-                dyn, opt_state, i, key = state
-
-                # Update iteration
-                i = i + 1
-
-                # Update PRNG key
-                key, _ = jr.split(key)
-
-                # Combine variational
-                vari = eqx.combine(dyn, static)
-
-                # Compute gradient of the ELBO
-                updates: PyTree = vari.elbo_grad(var_draws, key, data)
-
-                # Compute updates
-                updates, opt_state = optim.update(
-                    updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
-                )
-
-                # Update variational distribution
-                dyn = eqx.apply_updates(dyn, updates)
-
-                return dyn, opt_state, i, key
-
-            # Run optimization loop
-            dyn = lax.while_loop(
-                cond_fun=condition,
-                body_fun=body,
-                init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
-            )[0]
-
-            # Return optimized variational
-            return eqx.combine(dyn, static)
-
-        cls.fit = eqx.filter_jit(fit)
```
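Both versions vectorize `eval_model` over draws with `in_axes=(None, 0, None)` (broadcast `self` and `data`, map over `draws`); 0.2.23 attached the wrapped function in `__init_subclass__`, while 0.2.25 decorates the method directly. A toy illustration of that `in_axes` pattern (standalone function, not the library API):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(None, 0, None))
def weighted_sum(scale, draw, offset):
    # `scale` and `offset` are broadcast; `draw` is mapped over axis 0.
    return scale * draw.sum() + offset


weighted_sum(2.0, jnp.ones((5, 3)), 1.0)  # -> shape (5,), each entry 7.0
```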