bayinx 0.2.22__tar.gz → 0.2.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {bayinx-0.2.22 → bayinx-0.2.24}/PKG-INFO +1 -1
  2. {bayinx-0.2.22 → bayinx-0.2.24}/pyproject.toml +2 -2
  3. bayinx-0.2.24/src/bayinx/core/constraints.py +61 -0
  4. bayinx-0.2.24/src/bayinx/core/model.py +75 -0
  5. bayinx-0.2.24/src/bayinx/core/variational.py +162 -0
  6. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/flows/planar.py +1 -1
  7. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/normalizing_flow.py +1 -0
  8. {bayinx-0.2.22 → bayinx-0.2.24}/tests/test_variational.py +17 -15
  9. bayinx-0.2.22/src/bayinx/core/model.py +0 -74
  10. bayinx-0.2.22/src/bayinx/core/variational.py +0 -167
  11. {bayinx-0.2.22 → bayinx-0.2.24}/.github/workflows/release_and_publish.yml +0 -0
  12. {bayinx-0.2.22 → bayinx-0.2.24}/.gitignore +0 -0
  13. {bayinx-0.2.22 → bayinx-0.2.24}/README.md +0 -0
  14. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/__init__.py +0 -0
  15. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/core/__init__.py +0 -0
  16. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/core/flow.py +0 -0
  17. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/core/utils.py +0 -0
  18. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/__init__.py +0 -0
  19. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/bernoulli.py +0 -0
  20. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/binomial.py +0 -0
  21. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/gamma.py +0 -0
  22. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/gamma2.py +0 -0
  23. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/normal.py +0 -0
  24. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/dists/uniform.py +0 -0
  25. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/__init__.py +0 -0
  26. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/__init__.py +0 -0
  27. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  28. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  29. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  30. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  31. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/meanfield.py +0 -0
  32. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/mhx/vi/standard.py +0 -0
  33. {bayinx-0.2.22 → bayinx-0.2.24}/src/bayinx/py.typed +0 -0
  34. {bayinx-0.2.22 → bayinx-0.2.24}/tests/__init__.py +0 -0
  35. {bayinx-0.2.22 → bayinx-0.2.24}/uv.lock +0 -0
--- bayinx-0.2.22/PKG-INFO
+++ bayinx-0.2.24/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayinx
-Version: 0.2.22
+Version: 0.2.24
 Summary: Bayesian Inference with JAX
 Requires-Python: >=3.12
 Requires-Dist: equinox>=0.11.12
--- bayinx-0.2.22/pyproject.toml
+++ bayinx-0.2.24/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.2.22"
+version = "0.2.24"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
 
 [tool.bumpversion]
-current_version = "0.2.22"
+current_version = "0.2.24"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"
--- /dev/null
+++ bayinx-0.2.24/src/bayinx/core/constraints.py
@@ -0,0 +1,61 @@
+from abc import abstractmethod
+from typing import Tuple
+
+import equinox as eqx
+import jax.numpy as jnp
+from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
+
+
+class Constraint(eqx.Module):
+    """
+    Abstract base class for defining parameter constraints.
+
+    Subclasses should implement the `constrain` method to apply the
+    transformation and compute the ladj adjustment.
+    """
+    @abstractmethod
+    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+        """
+        Applies the constraining transformation to an unconstrained input
+        and computes the log absolute determinant of the Jacobian (ladj)
+        of this transformation.
+
+        # Parameters
+        - `x`: The unconstrained JAX Array-like input.
+
+        # Returns
+        A tuple containing:
+        - The constrained JAX Array.
+        - A scalar JAX Array representing the ladj of the transformation.
+        """
+        pass
+
+
+class LowerBound(Constraint):
+    """
+    Enforces a lower bound on the parameter.
+    """
+    lb: ScalarLike
+
+    def __init__(self, lb: ScalarLike):
+        self.lb = lb
+
+    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+        """
+        Applies the lower bound constraint and computes the ladj.
+
+        # Parameters
+        - `x`: The unconstrained JAX Array-like input.
+
+        # Returns
+        A tuple containing:
+        - The constrained JAX Array (x > self.lb).
+        - A scalar JAX Array representing the ladj of the transformation.
+        """
+        # Compute transformation adjustment
+        ladj = jnp.sum(x)
+
+        # Compute transformation
+        x = jnp.exp(x) + self.lb
+
+        return x, ladj
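
The new `Constraint` API is small enough to exercise directly. Below is a minimal usage sketch (not part of the diff; it assumes only the `LowerBound` class shown above). The transform is x ↦ exp(x) + lb, so the returned `ladj` is simply `jnp.sum(x)`:

```python
# Hypothetical usage sketch of the new Constraint API (illustrative values only).
import jax.numpy as jnp

from bayinx.core.constraints import LowerBound

positive = LowerBound(0.0)

raw = jnp.array([-1.0, 0.0, 2.0])        # unconstrained values
constrained, ladj = positive.constrain(raw)

# constrained == exp(raw) + 0.0, so every entry is > 0
# ladj == raw.sum(), the log|det J| of x -> exp(x) + lb
print(constrained)  # [0.3678..., 1.0, 7.389...]
print(ladj)         # 1.0
```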
--- /dev/null
+++ bayinx-0.2.24/src/bayinx/core/model.py
@@ -0,0 +1,75 @@
+from abc import abstractmethod
+from typing import Any, Dict, Tuple
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.tree_util as jtu
+from jaxtyping import Array, Scalar
+
+from bayinx.core.constraints import Constraint
+from bayinx.core.utils import __MyMeta
+
+
+class Model(eqx.Module, metaclass=__MyMeta):
+    """
+    A superclass used to define probabilistic models.
+
+    # Attributes
+    - `params`: A dictionary of JAX Arrays representing parameters of the model.
+    - `constraints`: A dictionary of `Constraint` objects that constrain their corresponding parameter.
+    """
+
+    params: Dict[str, Array]
+    constraints: Dict[str, Constraint]
+
+    @abstractmethod
+    def eval(self, data: Any) -> Scalar:
+        pass
+
+    # Default filter specification
+    def filter_spec(self):
+        """
+        Generates a filter specification to subset relevant parameters for the model.
+        """
+        # Generate empty specification
+        filter_spec = jtu.tree_map(lambda _: False, self)
+
+        # Specify JAX Array parameters
+        filter_spec = eqx.tree_at(
+            lambda model: model.params,
+            filter_spec,
+            replace=jtu.tree_map(eqx.is_array, self.params),
+        )
+
+        return filter_spec
+
+    # Constrain parameters to their domains
+    @eqx.filter_jit
+    def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+        """
+        Constrain `params` to the appropriate domain.
+
+        # Returns
+        A tuple containing a dictionary of the constrained parameters and the scalar adjustment to the posterior density.
+        """
+        t_params: Dict[str, Array] = dict(self.params)  # copy so `params` is not mutated in place
+        target: Scalar = jnp.array(0.0)
+
+        for par, map in self.constraints.items():
+            # Apply transformation
+            t_params[par], ladj = map.constrain(t_params[par])
+
+            # Adjust posterior density
+            target -= ladj
+
+        return t_params, target
+
+
+    def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+        """
+        Apply a custom transformation to `params` if needed.
+
+        # Returns
+        A tuple containing a dictionary of the transformed parameters and the scalar adjustment to the posterior density.
+        """
+        return self.constrain_pars()
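
The key behavioral change in `Model` is that `constrain_pars` now returns the density adjustment alongside the constrained parameters. This is the usual change-of-variables bookkeeping: if `u` is the unconstrained parameter and `T` the constraining transform, the log-density on the unconstrained scale picks up a log-abs-det-Jacobian term,

```latex
\log p_U(u) = \log p_\Theta\bigl(T(u)\bigr) + \log\left|\det J_T(u)\right|
```

`constrain_pars` accumulates the adjustment into `target` (with the package's own sign convention, `target -= ladj`), so callers, as in the updated tests below, add their log-density terms to `target` rather than discarding the correction.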
--- /dev/null
+++ bayinx-0.2.24/src/bayinx/core/variational.py
@@ -0,0 +1,162 @@
+from abc import abstractmethod
+from functools import partial
+from typing import Any, Callable, Self, Tuple
+
+import equinox as eqx
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.random as jr
+import optax as opx
+from jaxtyping import Array, Float, Key, PyTree, Scalar
+from optax import GradientTransformation, OptState, Schedule
+
+from bayinx.core import Model
+
+
+class Variational(eqx.Module):
+    """
+    A superclass used to define variational methods.
+
+    # Attributes
+    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
+    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
+    """
+
+    _unflatten: Callable[[Float[Array, "..."]], Model]
+    _constraints: Model
+
+    @abstractmethod
+    def sample(self, n: int, key: Key) -> Array:
+        """
+        Sample from the variational distribution.
+        """
+        pass
+
+    @abstractmethod
+    def eval(self, draws: Array) -> Array:
+        """
+        Evaluate the variational distribution at `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
+        """
+        Evaluate the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
+        """
+        Evaluate the gradient of the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def filter_spec(self):
+        """
+        Filter specification for dynamic and static components of the `Variational`.
+        """
+        pass
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0, None))
+    def eval_model(self, draws: Array, data: Any = None) -> Array:
+        """
+        Reconstruct models from variational draws and evaluate their posterior density.
+
+        # Parameters
+        - `draws`: A set of variational draws.
+        - `data`: Data used to evaluate the posterior (if needed).
+        """
+        # Unflatten variational draw
+        model: Model = self._unflatten(draws)
+
+        # Combine with constraints
+        model: Model = eqx.combine(model, self._constraints)
+
+        # Evaluate posterior density
+        return model.eval(data)
+
+    @eqx.filter_jit
+    def fit(
+        self,
+        max_iters: int,
+        data: Any = None,
+        learning_rate: float = 1,
+        weight_decay: float = 1e-4,
+        tolerance: float = 1e-4,
+        var_draws: int = 1,
+        key: Key = jr.PRNGKey(0),
+    ) -> Self:
+        """
+        Optimize the variational distribution.
+
+        # Parameters
+        - `max_iters`: Maximum number of iterations for the optimization loop.
+        - `data`: Data to evaluate the posterior density with (if available).
+        - `learning_rate`: Initial learning rate for optimization.
+        - `tolerance`: Relative tolerance of ELBO decrease for stopping early (not yet used by the loop).
+        - `var_draws`: Number of variational draws to draw each iteration.
+        - `key`: A PRNG key.
+        """
+        # Partition variational
+        dyn, static = eqx.partition(self, self.filter_spec())
+
+        # Construct scheduler
+        schedule: Schedule = opx.cosine_decay_schedule(
+            init_value=learning_rate, decay_steps=max_iters
+        )
+
+        # Initialize optimizer
+        optim: GradientTransformation = opx.chain(
+            opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+        )
+        opt_state: OptState = optim.init(dyn)
+
+        # Optimization loop helper functions
+        @eqx.filter_jit
+        def condition(state: Tuple[Self, OptState, Scalar, Key]):
+            # Unpack iteration state
+            dyn, opt_state, i, key = state
+
+            return i < max_iters
+
+        @eqx.filter_jit
+        def body(state: Tuple[Self, OptState, Scalar, Key]):
+            # Unpack iteration state
+            dyn, opt_state, i, key = state
+
+            # Update iteration
+            i = i + 1
+
+            # Update PRNG key
+            key, _ = jr.split(key)
+
+            # Combine variational
+            vari = eqx.combine(dyn, static)
+
+            # Compute gradient of the ELBO
+            updates: PyTree = vari.elbo_grad(var_draws, key, data)
+
+            # Compute updates
+            updates, opt_state = optim.update(
+                updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+            )
+
+            # Update variational distribution
+            dyn = eqx.apply_updates(dyn, updates)
+
+            return dyn, opt_state, i, key
+
+        # Run optimization loop
+        dyn = lax.while_loop(
+            cond_fun=condition,
+            body_fun=body,
+            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
+        )[0]
+
+        # Return optimized variational
+        return eqx.combine(dyn, static)
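
Compared to 0.2.22, `eval_model` and `fit` are now ordinary methods decorated with `eqx.filter_jit` (and `jax.vmap` via `functools.partial`) instead of being injected by `__init_subclass__`, which keeps them visible to type checkers and IDEs. A hypothetical end-to-end sketch of the `fit` entry point follows; `MeanField` and `NormalDist` stand in for a concrete variational family and a `Model` subclass defined elsewhere, and the constructor signature is an assumption, not something this diff shows:

```python
# Hypothetical driver for Variational.fit (names and constructor are assumptions).
import jax.random as jr

model = NormalDist()        # some Model subclass, defined elsewhere
vari = MeanField(model)     # some Variational subclass, defined elsewhere

# fit() partitions the variational with filter_spec(), optimizes the dynamic
# part with nadamw under a cosine-decay schedule for max_iters steps, then
# recombines it with the static part and returns the optimized variational.
vari = vari.fit(max_iters=20000, var_draws=4, key=jr.PRNGKey(0))
```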
--- bayinx-0.2.22/src/bayinx/mhx/vi/flows/planar.py
+++ bayinx-0.2.24/src/bayinx/mhx/vi/flows/planar.py
@@ -30,7 +30,7 @@ class Planar(Flow):
         - `dim`: The dimension of the parameter space.
         """
         self.params = {
-            "u": jnp.ones(dim),
+            "u": jnp.zeros(dim),
             "w": jnp.ones(dim),
             "b": jnp.zeros(1),
         }
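
The `"u": jnp.ones(dim)` → `jnp.zeros(dim)` change plausibly makes each planar flow start as the identity map: a planar flow has the form f(x) = x + u · tanh(wᵀx + b), so u = 0 leaves x untouched at initialization. A sketch under that assumption (the actual forward pass lives in `planar.py` and is not shown in this diff):

```python
# Why "u": jnp.zeros(dim) is a sensible init: with u = 0 the flow is the identity.
import jax.numpy as jnp

def planar_forward(x, u, w, b):
    # Standard planar-flow form; assumed, not taken from planar.py.
    return x + u * jnp.tanh(w @ x + b)

x = jnp.array([1.0, -2.0])
u, w, b = jnp.zeros(2), jnp.ones(2), jnp.zeros(1)
print(planar_forward(x, u, w, b))  # [ 1. -2.] -- unchanged at init
```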
--- bayinx-0.2.22/src/bayinx/mhx/vi/normalizing_flow.py
+++ bayinx-0.2.24/src/bayinx/mhx/vi/normalizing_flow.py
@@ -144,6 +144,7 @@ class NormalizingFlow(Variational):
         draws: Array = self.base.sample(n, key)
 
         posterior_evals, variational_evals = self.__eval(draws, data)
+
         # Evaluate ELBO
         return jnp.mean(posterior_evals - variational_evals)
 
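
For reference, the quantity returned here is the standard Monte Carlo ELBO estimate over `n` draws θ⁽ⁱ⁾ from the variational distribution q, with `posterior_evals` and `variational_evals` supplying the two terms:

```latex
\widehat{\mathrm{ELBO}} = \frac{1}{n} \sum_{i=1}^{n}
    \Bigl[ \log p\bigl(\theta^{(i)}, \mathcal{D}\bigr) - \log q\bigl(\theta^{(i)}\bigr) \Bigr]
```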
--- bayinx-0.2.22/tests/test_variational.py
+++ bayinx-0.2.24/tests/test_variational.py
@@ -26,12 +26,13 @@ def test_meanfield(benchmark, var_draws):
     @eqx.filter_jit
     def eval(self, data: dict):
         # Get constrained parameters
-        params = self.constrain_pars()
+        params, target = self.constrain_pars()
 
         # Evaluate mu ~ N(10,1)
-        return jnp.sum(
-            normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
-        )
+        target += normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)).sum()
+
+        # Return the posterior density
+        return target
 
     # Construct model
     model = NormalDist()
@@ -44,7 +45,7 @@ def test_meanfield(benchmark, var_draws):
         vari.fit(10000, var_draws=var_draws)
 
     benchmark(benchmark_fit)
-    vari = vari.fit(20000)
+    vari = vari.fit(20000, var_draws=var_draws)
 
     # Assert parameters are roughly correct
     assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
@@ -66,12 +67,13 @@ def test_affine(benchmark, var_draws):
     @eqx.filter_jit
     def eval(self, data: dict):
         # Get constrained parameters
-        params = self.constrain_pars()
+        params, target = self.constrain_pars()
 
         # Evaluate mu ~ N(10,1)
-        return jnp.sum(
-            normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
-        )
+        target += normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)).sum()
+
+        # Return the posterior density
+        return target
 
     # Construct model
     model = NormalDist()
@@ -84,7 +86,7 @@ def test_affine(benchmark, var_draws):
         vari.fit(10000, var_draws=var_draws)
 
     benchmark(benchmark_fit)
-    vari = vari.fit(20000)
+    vari = vari.fit(20000, var_draws=var_draws)
 
     params = vari.flows[0].constrain_pars()
     assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
@@ -106,12 +108,12 @@ def test_flows(benchmark, var_draws):
     @eqx.filter_jit
     def eval(self, data: dict):
         # Get constrained parameters
-        params = self.constrain_pars()
+        params, target = self.constrain_pars()
 
         # Evaluate mu ~ N(10,1)
-        return jnp.sum(
-            normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
-        )
+        target += normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)).sum()
+
+        return target
 
     # Construct model
     model = NormalDist()
@@ -126,7 +128,7 @@ def test_flows(benchmark, var_draws):
        vari.fit(10000, var_draws=var_draws)
 
     benchmark(benchmark_fit)
-    vari = vari.fit(20000)
+    vari = vari.fit(20000, var_draws=var_draws)
 
     mean = vari.sample(1000).mean(0)
     var = vari.sample(1000).var(0)
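
The test updates do two things: they unpack the new `(params, target)` tuple so the Jacobian adjustment is carried through `eval`, and the final `fit(20000, ...)` calls now forward `var_draws`, so the convergence run uses the same number of draws per iteration as the benchmark. A quick, hypothetical sanity check of the new contract, verifying `LowerBound`'s analytic `ladj` against the autodiff Jacobian of the transform (not a test from the package):

```python
# Hypothetical sanity check: LowerBound's ladj vs. the autodiff Jacobian
# of x -> exp(x) + lb (elementwise, so log|det J| = sum of log-derivatives).
import jax
import jax.numpy as jnp

from bayinx.core.constraints import LowerBound

x = jnp.array([0.5, -1.0, 3.0])
_, ladj = LowerBound(2.0).constrain(x)

# d/dx (exp(x) + 2) = exp(x), so each log-derivative is just x itself.
autodiff_ladj = jnp.log(jax.vmap(jax.grad(lambda v: jnp.exp(v) + 2.0))(x)).sum()

print(ladj, autodiff_ladj)  # both 2.5
```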
--- bayinx-0.2.22/src/bayinx/core/model.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Callable, Dict
-
-import equinox as eqx
-import jax.tree_util as jtu
-from jaxtyping import Array, Scalar
-
-from bayinx.core.utils import __MyMeta
-
-
-class Model(eqx.Module, metaclass=__MyMeta):
-    """
-    A superclass used to define probabilistic models.
-
-    # Attributes
-    - `params`: A dictionary of JAX Arrays representing parameters of the model.
-    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
-    """
-
-    params: Dict[str, Array]
-    constraints: Dict[str, Callable[[Array], Array]]
-
-    @abstractmethod
-    def eval(self, data: Any) -> Scalar:
-        pass
-
-    # Default filter specification
-    def filter_spec(self):
-        """
-        Generates a filter specification to subset relevant parameters for the model.
-        """
-        # Generate empty specification
-        filter_spec = jtu.tree_map(lambda _: False, self)
-
-        # Specify JAX Array parameters
-        filter_spec = eqx.tree_at(
-            lambda model: model.params,
-            filter_spec,
-            replace=jtu.tree_map(eqx.is_array, self.params),
-        )
-
-        return filter_spec
-
-    def __init_subclass__(cls):
-        # Add constrain method
-        def constrain_pars(self: Model) -> Dict[str, Array]:
-            """
-            Constrain `params` to the appropriate domain.
-
-            # Returns
-            A dictionary of transformed JAX Arrays representing the constrained parameters.
-            """
-            t_params = self.params
-
-            for par, map in self.constraints.items():
-                t_params[par] = map(t_params[par])
-
-            return t_params
-
-        cls.constrain_pars = eqx.filter_jit(constrain_pars)
-
-        # Add transform_pars method if not present
-        if not callable(getattr(cls, "transform_pars", None)):
-
-            def transform_pars(self: Model) -> Dict[str, Array]:
-                """
-                Apply a custom transformation to `params` if needed.
-
-                # Returns
-                A dictionary of transformed JAX Arrays representing the transformed parameters.
-                """
-                return self.constrain_pars()
-
-            cls.transform_pars = eqx.filter_jit(transform_pars)
--- bayinx-0.2.22/src/bayinx/core/variational.py
+++ /dev/null
@@ -1,167 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Callable, Self, Tuple
-
-import equinox as eqx
-import jax
-import jax.lax as lax
-import jax.numpy as jnp
-import jax.random as jr
-import optax as opx
-from jaxtyping import Array, Float, Key, PyTree, Scalar
-from optax import GradientTransformation, OptState, Schedule
-
-from bayinx.core import Model
-
-
-class Variational(eqx.Module):
-    """
-    A superclass used to define variational methods.
-
-    # Attributes
-    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
-    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
-    """
-
-    _unflatten: Callable[[Float[Array, "..."]], Model]
-    _constraints: Model
-
-    @abstractmethod
-    def sample(self, n: int, key: Key) -> Array:
-        """
-        Sample from the variational distribution.
-        """
-        pass
-
-    @abstractmethod
-    def eval(self, draws: Array) -> Array:
-        """
-        Evaluate the variational distribution at `draws`.
-        """
-        pass
-
-    @abstractmethod
-    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
-        """
-        Evaluate the ELBO.
-        """
-        pass
-
-    @abstractmethod
-    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
-        """
-        Evaluate the gradient of the ELBO.
-        """
-        pass
-
-    @abstractmethod
-    def filter_spec(self):
-        """
-        Filter specification for dynamic and static components of the `Variational`.
-        """
-        pass
-
-    def __init_subclass__(cls):
-        """
-        Construct methods that are shared across all VI methods.
-        """
-
-        def eval_model(self, draws: Array, data: Any = None) -> Array:
-            """
-            Reconstruct models from variational draws and evaluate their posterior density.
-
-            # Parameters
-            - `draws`: A set of variational draws.
-            - `data`: Data used to evaluate the posterior(if needed).
-            """
-            # Unflatten variational draw
-            model: Model = self._unflatten(draws)
-
-            # Combine with constraints
-            model: Model = eqx.combine(model, self._constraints)
-
-            # Evaluate posterior density
-            return model.eval(data)
-
-        cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))
-
-        def fit(
-            self,
-            max_iters: int,
-            data: Any = None,
-            learning_rate: float = 1,
-            weight_decay: float = 1e-4,
-            tolerance: float = 1e-4,
-            var_draws: int = 1,
-            key: Key = jr.PRNGKey(0),
-        ) -> Self:
-            """
-            Optimize the variational distribution.
-
-            # Parameters
-            - `max_iters`: Maximum number of iterations for the optimization loop.
-            - `data`: Data to evaluate the posterior density with(if available).
-            - `learning_rate`: Initial learning rate for optimization.
-            - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
-            - `var_draws`: Number of variational draws to draw each iteration.
-            - `key`: A PRNG key.
-            """
-            # Partition variational
-            dyn, static = eqx.partition(self, self.filter_spec())
-
-            # Construct scheduler
-            schedule: Schedule = opx.cosine_decay_schedule(
-                init_value=learning_rate, decay_steps=max_iters
-            )
-
-            # Initialize optimizer
-            optim: GradientTransformation = opx.chain(
-                opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
-            )
-            opt_state: OptState = optim.init(dyn)
-
-            # Optimization loop helper functions
-            @eqx.filter_jit
-            def condition(state: Tuple[Self, OptState, Scalar, Key]):
-                # Unpack iteration state
-                dyn, opt_state, i, key = state
-
-                return i < max_iters
-
-            @eqx.filter_jit
-            def body(state: Tuple[Self, OptState, Scalar, Key]):
-                # Unpack iteration state
-                dyn, opt_state, i, key = state
-
-                # Update iteration
-                i = i + 1
-
-                # Update PRNG key
-                key, _ = jr.split(key)
-
-                # Combine variational
-                vari = eqx.combine(dyn, static)
-
-                # Compute gradient of the ELBO
-                updates: PyTree = vari.elbo_grad(var_draws, key, data)
-
-                # Compute updates
-                updates, opt_state = optim.update(
-                    updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
-                )
-
-                # Update variational distribution
-                dyn = eqx.apply_updates(dyn, updates)
-
-                return dyn, opt_state, i, key
-
-            # Run optimization loop
-            dyn = lax.while_loop(
-                cond_fun=condition,
-                body_fun=body,
-                init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
-            )[0]
-
-            # Return optimized variational
-            return eqx.combine(dyn, static)
-
-        cls.fit = eqx.filter_jit(fit)