bayinx 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/core/constraints.py ADDED
@@ -0,0 +1,61 @@
+ from abc import abstractmethod
+ from typing import Tuple
+
+ import equinox as eqx
+ import jax.numpy as jnp
+ from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
+
+
+ class Constraint(eqx.Module):
+     """
+     Abstract base class for defining parameter constraints.
+
+     Subclasses should implement the `constrain` method to apply the
+     transformation and compute the ladj adjustment.
+     """
+     @abstractmethod
+     def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+         """
+         Applies the constraining transformation to an unconstrained input
+         and computes the log absolute determinant of the Jacobian (ladj)
+         of this transformation.
+
+         # Parameters
+         - `x`: The unconstrained JAX Array-like input.
+
+         # Returns
+         A tuple containing:
+         - The constrained JAX Array.
+         - A scalar JAX Array representing the ladj of the transformation.
+         """
+         pass
+
+
+ class LowerBound(Constraint):
+     """
+     Enforces a lower bound on the parameter.
+     """
+     lb: ScalarLike
+
+     def __init__(self, lb: ScalarLike):
+         self.lb = lb
+
+     def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+         """
+         Applies the lower bound constraint and computes the ladj.
+
+         # Parameters
+         - `x`: The unconstrained JAX Array-like input.
+
+         # Returns
+         A tuple containing:
+         - The constrained JAX Array (x > self.lb).
+         - A scalar JAX Array representing the ladj of the transformation.
+         """
+         # Compute transformation adjustment
+         ladj = jnp.sum(x)
+
+         # Compute transformation
+         x = jnp.exp(x) + self.lb
+
+         return x, ladj
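
For y = exp(x) + lb the Jacobian is diagonal with dy/dx = exp(x), so the log absolute determinant is just the sum of the unconstrained values; that is why `constrain` returns `jnp.sum(x)` as the ladj. A minimal sketch of the new API in use, assuming bayinx 0.2.24 is installed (variable names here are illustrative):

    import jax.numpy as jnp
    from bayinx.core.constraints import LowerBound

    # Keep a scale parameter positive: y = exp(x) + 0.0
    positive = LowerBound(0.0)

    x = jnp.array([-1.0, 0.5, 2.0])    # unconstrained values
    y, ladj = positive.constrain(x)    # y = exp(x) + 0.0, elementwise

    assert bool(jnp.all(y > 0.0))      # the bound is respected
    assert jnp.isclose(ladj, x.sum())  # ladj = sum(x) for this transform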
bayinx/core/model.py CHANGED
@@ -1,10 +1,12 @@
  from abc import abstractmethod
- from typing import Any, Callable, Dict
+ from typing import Any, Dict, Tuple

  import equinox as eqx
+ import jax.numpy as jnp
  import jax.tree_util as jtu
  from jaxtyping import Array, Scalar

+ from bayinx.core.constraints import Constraint
  from bayinx.core.utils import __MyMeta


@@ -18,7 +20,7 @@ class Model(eqx.Module, metaclass=__MyMeta):
      """

      params: Dict[str, Array]
-     constraints: Dict[str, Callable[[Array], Array]]
+     constraints: Dict[str, Constraint]

      @abstractmethod
      def eval(self, data: Any) -> Scalar:
@@ -41,34 +43,33 @@ class Model(eqx.Module, metaclass=__MyMeta):

          return filter_spec

-     def __init_subclass__(cls):
-         # Add constrain method
-         def constrain_pars(self: Model) -> Dict[str, Array]:
-             """
-             Constrain `params` to the appropriate domain.
-
-             # Returns
-             A dictionary of transformed JAX Arrays representing the constrained parameters.
-             """
-             t_params = self.params
+     # Add constrain method
+     @eqx.filter_jit
+     def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+         """
+         Constrain `params` to the appropriate domain.

-             for par, map in self.constraints.items():
-                 t_params[par] = map(t_params[par])
+         # Returns
+         A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
+         """
+         t_params: Dict[str, Array] = self.params
+         target: Scalar = jnp.array(0.0)

-             return t_params
+         for par, map in self.constraints.items():
+             # Apply transformation
+             t_params[par], ladj = map.constrain(t_params[par])

-         cls.constrain_pars = eqx.filter_jit(constrain_pars)
+             # Adjust posterior density
+             target -= ladj

-         # Add transform_pars method if not present
-         if not callable(getattr(cls, "transform_pars", None)):
+         return t_params, target

-             def transform_pars(self: Model) -> Dict[str, Array]:
-                 """
-                 Apply a custom transformation to `params` if needed.

-                 # Returns
-                 A dictionary of transformed JAX Arrays representing the transformed parameters.
-                 """
-                 return self.constrain_pars()
+     def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+         """
+         Apply a custom transformation to `params` if needed.

-             cls.transform_pars = eqx.filter_jit(transform_pars)
+         # Returns
+         A dictionary of transformed JAX Arrays representing the transformed parameters.
+         """
+         return self.constrain_pars()
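
`constrain_pars` now threads a running `target` scalar through the constraint loop, accumulating the negated ladj of each transformation so the posterior density can be corrected for the change of variables. A standalone restatement of that loop, with illustrative parameter names, assuming the package is installed:

    import jax.numpy as jnp
    from bayinx.core.constraints import LowerBound

    params = {"mu": jnp.array(0.3), "sigma": jnp.array(-1.0)}  # unconstrained
    constraints = {"sigma": LowerBound(0.0)}                   # sigma > 0

    t_params = dict(params)
    target = jnp.array(0.0)
    for par, constraint in constraints.items():
        # Apply the transformation and collect its Jacobian adjustment
        t_params[par], ladj = constraint.constrain(t_params[par])
        target -= ladj

    # t_params["sigma"] == exp(-1.0) > 0, and target == 1.0
    # because ladj = sum(x) = -1.0 for the exp-plus-shift transform.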
bayinx/core/variational.py CHANGED
@@ -1,4 +1,5 @@
  from abc import abstractmethod
+ from functools import partial
  from typing import Any, Callable, Self, Tuple

  import equinox as eqx
@@ -60,108 +61,102 @@ class Variational(eqx.Module):
          """
          pass

-     def __init_subclass__(cls):
-         """
-         Construct methods that are shared across all VI methods.
+     @eqx.filter_jit
+     @partial(jax.vmap, in_axes=(None, 0, None))
+     def eval_model(self, draws: Array, data: Any = None) -> Array:
          """
+         Reconstruct models from variational draws and evaluate their posterior density.

-         def eval_model(self, draws: Array, data: Any = None) -> Array:
-             """
-             Reconstruct models from variational draws and evaluate their posterior density.
-
-             # Parameters
-             - `draws`: A set of variational draws.
-             - `data`: Data used to evaluate the posterior (if needed).
-             """
-             # Unflatten variational draw
-             model: Model = self._unflatten(draws)
-
-             # Combine with constraints
-             model: Model = eqx.combine(model, self._constraints)
-
-             # Evaluate posterior density
-             return model.eval(data)
-
-         cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))
-
-         def fit(
-             self,
-             max_iters: int,
-             data: Any = None,
-             learning_rate: float = 1,
-             weight_decay: float = 1e-4,
-             tolerance: float = 1e-4,
-             var_draws: int = 1,
-             key: Key = jr.PRNGKey(0),
-         ) -> Self:
-             """
-             Optimize the variational distribution.
-
-             # Parameters
-             - `max_iters`: Maximum number of iterations for the optimization loop.
-             - `data`: Data to evaluate the posterior density with (if available).
-             - `learning_rate`: Initial learning rate for optimization.
-             - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
-             - `var_draws`: Number of variational draws to draw each iteration.
-             - `key`: A PRNG key.
-             """
-             # Partition variational
-             dyn, static = eqx.partition(self, self.filter_spec())
-
-             # Construct scheduler
-             schedule: Schedule = opx.cosine_decay_schedule(
-                 init_value=learning_rate, decay_steps=max_iters
-             )
+         # Parameters
+         - `draws`: A set of variational draws.
+         - `data`: Data used to evaluate the posterior (if needed).
+         """
+         # Unflatten variational draw
+         model: Model = self._unflatten(draws)
+
+         # Combine with constraints
+         model: Model = eqx.combine(model, self._constraints)
+
+         # Evaluate posterior density
+         return model.eval(data)
+
+     @eqx.filter_jit
+     def fit(
+         self,
+         max_iters: int,
+         data: Any = None,
+         learning_rate: float = 1,
+         weight_decay: float = 1e-4,
+         tolerance: float = 1e-4,
+         var_draws: int = 1,
+         key: Key = jr.PRNGKey(0),
+     ) -> Self:
+         """
+         Optimize the variational distribution.
+
+         # Parameters
+         - `max_iters`: Maximum number of iterations for the optimization loop.
+         - `data`: Data to evaluate the posterior density with (if available).
+         - `learning_rate`: Initial learning rate for optimization.
+         - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
+         - `var_draws`: Number of variational draws to draw each iteration.
+         - `key`: A PRNG key.
+         """
+         # Partition variational
+         dyn, static = eqx.partition(self, self.filter_spec())

-             # Initialize optimizer
-             optim: GradientTransformation = opx.chain(
-                 opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
-             )
-             opt_state: OptState = optim.init(dyn)
+         # Construct scheduler
+         schedule: Schedule = opx.cosine_decay_schedule(
+             init_value=learning_rate, decay_steps=max_iters
+         )

-             # Optimization loop helper functions
-             @eqx.filter_jit
-             def condition(state: Tuple[Self, OptState, Scalar, Key]):
-                 # Unpack iteration state
-                 dyn, opt_state, i, key = state
+         # Initialize optimizer
+         optim: GradientTransformation = opx.chain(
+             opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+         )
+         opt_state: OptState = optim.init(dyn)

-                 return i < max_iters
+         # Optimization loop helper functions
+         @eqx.filter_jit
+         def condition(state: Tuple[Self, OptState, Scalar, Key]):
+             # Unpack iteration state
+             dyn, opt_state, i, key = state

-             @eqx.filter_jit
-             def body(state: Tuple[Self, OptState, Scalar, Key]):
-                 # Unpack iteration state
-                 dyn, opt_state, i, key = state
+             return i < max_iters

-                 # Update iteration
-                 i = i + 1
+         @eqx.filter_jit
+         def body(state: Tuple[Self, OptState, Scalar, Key]):
+             # Unpack iteration state
+             dyn, opt_state, i, key = state

-                 # Update PRNG key
-                 key, _ = jr.split(key)
+             # Update iteration
+             i = i + 1

-                 # Combine variational
-                 vari = eqx.combine(dyn, static)
+             # Update PRNG key
+             key, _ = jr.split(key)

-                 # Compute gradient of the ELBO
-                 updates: PyTree = vari.elbo_grad(var_draws, key, data)
+             # Combine variational
+             vari = eqx.combine(dyn, static)

-                 # Compute updates
-                 updates, opt_state = optim.update(
-                     updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
-                 )
+             # Compute gradient of the ELBO
+             updates: PyTree = vari.elbo_grad(var_draws, key, data)

-                 # Update variational distribution
-                 dyn = eqx.apply_updates(dyn, updates)
+             # Compute updates
+             updates, opt_state = optim.update(
+                 updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+             )

-                 return dyn, opt_state, i, key
+             # Update variational distribution
+             dyn = eqx.apply_updates(dyn, updates)

-             # Run optimization loop
-             dyn = lax.while_loop(
-                 cond_fun=condition,
-                 body_fun=body,
-                 init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
-             )[0]
+             return dyn, opt_state, i, key

-             # Return optimized variational
-             return eqx.combine(dyn, static)
+         # Run optimization loop
+         dyn = lax.while_loop(
+             cond_fun=condition,
+             body_fun=body,
+             init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
+         )[0]

-         cls.fit = eqx.filter_jit(fit)
+         # Return optimized variational
+         return eqx.combine(dyn, static)
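
The `__init_subclass__` plumbing is gone: `eval_model` is now a plain method whose decorator stack replaces the old `jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))` wrapping. `partial(jax.vmap, in_axes=(None, 0, None))` maps over axis 0 of `draws` while broadcasting `self` and `data` to every draw. A self-contained sketch of the pattern with a hypothetical `Demo` module (not part of bayinx):

    from functools import partial

    import equinox as eqx
    import jax
    import jax.numpy as jnp


    class Demo(eqx.Module):
        # Same decorator stack as the new eval_model: the function runs
        # once per row of `draws`; `self` and `data` are not batched.
        @eqx.filter_jit
        @partial(jax.vmap, in_axes=(None, 0, None))
        def eval_model(self, draws, data):
            return jnp.sum(draws)


    draws = jnp.ones((4, 3))               # 4 draws of dimension 3
    print(Demo().eval_model(draws, None))  # [3. 3. 3. 3.], one value per draw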
bayinx/mhx/vi/flows/planar.py CHANGED
@@ -30,7 +30,7 @@ class Planar(Flow):
          - `dim`: The dimension of the parameter space.
          """
          self.params = {
-             "u": jnp.ones(dim),
+             "u": jnp.zeros(dim),
              "w": jnp.ones(dim),
              "b": jnp.zeros(1),
          }
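
Zero-initializing `u` makes the flow start as an identity map: in the standard planar transform f(x) = x + u * tanh(w @ x + b) (Rezende & Mohamed, 2015), u = 0 leaves x unchanged, so optimization begins at the base distribution rather than an arbitrarily distorted one. A quick illustrative check (the `planar` helper below is a sketch, not the package's implementation):

    import jax.numpy as jnp

    def planar(x, u, w, b):
        # Standard planar flow: f(x) = x + u * tanh(w @ x + b)
        return x + u * jnp.tanh(w @ x + b)

    dim = 3
    x = jnp.array([0.1, -0.4, 2.0])
    u, w, b = jnp.zeros(dim), jnp.ones(dim), jnp.zeros(1)

    print(jnp.allclose(planar(x, u, w, b), x))  # True: the flow starts at identity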
bayinx/mhx/vi/normalizing_flow.py CHANGED
@@ -144,6 +144,7 @@ class NormalizingFlow(Variational):
          draws: Array = self.base.sample(n, key)

          posterior_evals, variational_evals = self.__eval(draws, data)
+
          # Evaluate ELBO
          return jnp.mean(posterior_evals - variational_evals)

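For reference, `jnp.mean(posterior_evals - variational_evals)` is the standard Monte Carlo estimate of the evidence lower bound over N variational draws:

    \widehat{\mathrm{ELBO}} = \frac{1}{N} \sum_{i=1}^{N} \left[ \log p(\theta_i, \mathcal{D}) - \log q_\phi(\theta_i) \right], \qquad \theta_i \sim q_\phi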
bayinx-0.2.22.dist-info/METADATA → bayinx-0.2.24.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayinx
- Version: 0.2.22
+ Version: 0.2.24
  Summary: Bayesian Inference with JAX
  Requires-Python: >=3.12
  Requires-Dist: equinox>=0.11.12
bayinx-0.2.22.dist-info/RECORD → bayinx-0.2.24.dist-info/RECORD RENAMED
@@ -1,10 +1,11 @@
  bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
+ bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
  bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
- bayinx/core/model.py,sha256=-rT3NHjxqGB0lDBMi0Mr9XNOz1_TUnJWtd4ITj0rsus,2257
+ bayinx/core/model.py,sha256=t7s5Yt4E3iC_MPvynJnk6lb4OLal7Gnk59tIZ6e5M4I,2203
  bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
- bayinx/core/variational.py,sha256=k9wWn7Tnj3eET-qK1pZtzDyPZVvQTRUexJUBVSdGXOA,5251
+ bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
  bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -15,13 +16,13 @@ bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
  bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
- bayinx/mhx/vi/normalizing_flow.py,sha256=XBUWYZpm_Ipi6X9oTnGhqIs3ARY-5QFiuxM7uAWFRps,4790
+ bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
  bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
  bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
- bayinx/mhx/vi/flows/planar.py,sha256=qmtWpIBXRct2seI78pkmtF0X7cASUBELqmZmf2QS5Gs,1918
+ bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
  bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
- bayinx-0.2.22.dist-info/METADATA,sha256=TwM3DxTXPTttN0TGJOPpM1pRMzDvoUjFZWkWWCR9vNI,3058
- bayinx-0.2.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- bayinx-0.2.22.dist-info/RECORD,,
+ bayinx-0.2.24.dist-info/METADATA,sha256=sR0C0Pk5vrAmdvAtB3faXZO-hIDpKzqLjnXcfMsikjw,3058
+ bayinx-0.2.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ bayinx-0.2.24.dist-info/RECORD,,