bayinx 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/constraints.py +61 -0
- bayinx/core/model.py +27 -26
- bayinx/core/variational.py +84 -89
- bayinx/mhx/vi/flows/planar.py +1 -1
- bayinx/mhx/vi/normalizing_flow.py +1 -0
- {bayinx-0.2.22.dist-info → bayinx-0.2.24.dist-info}/METADATA +1 -1
- {bayinx-0.2.22.dist-info → bayinx-0.2.24.dist-info}/RECORD +8 -7
- {bayinx-0.2.22.dist-info → bayinx-0.2.24.dist-info}/WHEEL +0 -0
bayinx/core/constraints.py
ADDED
@@ -0,0 +1,61 @@
+from abc import abstractmethod
+from typing import Tuple
+
+import equinox as eqx
+import jax.numpy as jnp
+from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
+
+
+class Constraint(eqx.Module):
+    """
+    Abstract base class for defining parameter constraints.
+
+    Subclasses should implement the `constrain` method to apply the
+    transformation and compute the ladj adjustment.
+    """
+    @abstractmethod
+    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+        """
+        Applies the constraining transformation to an unconstrained input
+        and computes the log absolute determinant of the Jacobian (ladj)
+        of this transformation.
+
+        # Parameters
+        - `x`: The unconstrained JAX Array-like input.
+
+        # Returns
+        A tuple containing:
+        - The constrained JAX Array.
+        - A scalar JAX Array representing the ladj of the transformation.
+        """
+        pass
+
+
+class LowerBound(Constraint):
+    """
+    Enforces a lower bound on the parameter.
+    """
+    lb: ScalarLike
+
+    def __init__(self, lb: ScalarLike):
+        self.lb = lb
+
+    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
+        """
+        Applies the lower bound constraint and computes the ladj.
+
+        # Parameters
+        - `x`: The unconstrained JAX Array-like input.
+
+        # Parameters
+        A tuple containing:
+        - The constrained JAX Array (x > self.lb).
+        - A scalar JAX Array representing the ladj of the transformation.
+        """
+        # Compute transformation adjustment
+        ladj = jnp.sum(x)
+
+        # Compute transformation
+        x = jnp.exp(x) + self.lb
+
+        return x, ladj
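For orientation, a minimal usage sketch of the new constraint (not part of the release itself): `LowerBound.constrain` maps an unconstrained value through exp(x) + lb and returns sum(x) as the log absolute determinant (ladj) of that transformation's Jacobian.

import jax.numpy as jnp

from bayinx.core.constraints import LowerBound

# Constrain unconstrained draws to values strictly above 0.
bound = LowerBound(0.0)

x = jnp.array([-1.0, 0.0, 2.0])         # unconstrained input
constrained, ladj = bound.constrain(x)  # applies exp(x) + lb elementwise

print(constrained)  # roughly [0.37, 1.0, 7.39], all > 0
print(ladj)         # sum(x) = 1.0, the ladj of the elementwise exp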
bayinx/core/model.py
CHANGED
@@ -1,10 +1,12 @@
 from abc import abstractmethod
-from typing import Any,
+from typing import Any, Dict, Tuple
 
 import equinox as eqx
+import jax.numpy as jnp
 import jax.tree_util as jtu
 from jaxtyping import Array, Scalar
 
+from bayinx.core.constraints import Constraint
 from bayinx.core.utils import __MyMeta
 
 
@@ -18,7 +20,7 @@ class Model(eqx.Module, metaclass=__MyMeta):
     """
 
     params: Dict[str, Array]
-    constraints: Dict[str,
+    constraints: Dict[str, Constraint]
 
     @abstractmethod
     def eval(self, data: Any) -> Scalar:
@@ -41,34 +43,33 @@ class Model(eqx.Module, metaclass=__MyMeta):
 
         return filter_spec
 
-
-
-
-
-
-
-        # Returns
-        A dictionary of transformed JAX Arrays representing the constrained parameters.
-        """
-        t_params = self.params
+    # Add constrain method
+    @eqx.filter_jit
+    def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+        """
+        Constrain `params` to the appropriate domain.
 
-
-
+        # Returns
+        A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
+        """
+        t_params: Dict[str, Array] = self.params
+        target: Scalar = jnp.array(0.0)
 
-
+        for par, map in self.constraints.items():
+            # Apply transformation
+            t_params[par], ladj = map.constrain(t_params[par])
 
-
+            # Adjust posterior density
+            target -= ladj
 
-
-        if not callable(getattr(cls, "transform_pars", None)):
+        return t_params, target
 
-    def transform_pars(self: Model) -> Dict[str, Array]:
-        """
-        Apply a custom transformation to `params` if needed.
 
-
-
-        return self.constrain_pars()
+    def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+        """
+        Apply a custom transformation to `params` if needed.
 
-
+        # Returns
+        A dictionary of transformed JAX Arrays representing the transformed parameters.
+        """
+        return self.constrain_pars()
bayinx/core/variational.py
CHANGED
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from functools import partial
 from typing import Any, Callable, Self, Tuple
 
 import equinox as eqx
@@ -60,108 +61,102 @@ class Variational(eqx.Module):
         """
         pass
 
-
-
-
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0, None))
+    def eval_model(self, draws: Array, data: Any = None) -> Array:
         """
+        Reconstruct models from variational draws and evaluate their posterior density.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        - `var_draws`: Number of variational draws to draw each iteration.
-        - `key`: A PRNG key.
-        """
-        # Partition variational
-        dyn, static = eqx.partition(self, self.filter_spec())
-
-        # Construct scheduler
-        schedule: Schedule = opx.cosine_decay_schedule(
-            init_value=learning_rate, decay_steps=max_iters
-        )
+        # Parameters
+        - `draws`: A set of variational draws.
+        - `data`: Data used to evaluate the posterior(if needed).
+        """
+        # Unflatten variational draw
+        model: Model = self._unflatten(draws)
+
+        # Combine with constraints
+        model: Model = eqx.combine(model, self._constraints)
+
+        # Evaluate posterior density
+        return model.eval(data)
+
+    @eqx.filter_jit
+    def fit(
+        self,
+        max_iters: int,
+        data: Any = None,
+        learning_rate: float = 1,
+        weight_decay: float = 1e-4,
+        tolerance: float = 1e-4,
+        var_draws: int = 1,
+        key: Key = jr.PRNGKey(0),
+    ) -> Self:
+        """
+        Optimize the variational distribution.
+
+        # Parameters
+        - `max_iters`: Maximum number of iterations for the optimization loop.
+        - `data`: Data to evaluate the posterior density with(if available).
+        - `learning_rate`: Initial learning rate for optimization.
+        - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
+        - `var_draws`: Number of variational draws to draw each iteration.
+        - `key`: A PRNG key.
+        """
+        # Partition variational
+        dyn, static = eqx.partition(self, self.filter_spec())
 
-
-
-
-
-        opt_state: OptState = optim.init(dyn)
+        # Construct scheduler
+        schedule: Schedule = opx.cosine_decay_schedule(
+            init_value=learning_rate, decay_steps=max_iters
+        )
 
-
-
-
-
-
+        # Initialize optimizer
+        optim: GradientTransformation = opx.chain(
+            opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+        )
+        opt_state: OptState = optim.init(dyn)
 
-
+        # Optimization loop helper functions
+        @eqx.filter_jit
+        def condition(state: Tuple[Self, OptState, Scalar, Key]):
+            # Unpack iteration state
+            dyn, opt_state, i, key = state
 
-
-        def body(state: Tuple[Self, OptState, Scalar, Key]):
-            # Unpack iteration state
-            dyn, opt_state, i, key = state
+            return i < max_iters
 
-
-
+        @eqx.filter_jit
+        def body(state: Tuple[Self, OptState, Scalar, Key]):
+            # Unpack iteration state
+            dyn, opt_state, i, key = state
 
-
-
+            # Update iteration
+            i = i + 1
 
-
-
+            # Update PRNG key
+            key, _ = jr.split(key)
 
-
-
+            # Combine variational
+            vari = eqx.combine(dyn, static)
 
-
-
-                updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
-            )
+            # Compute gradient of the ELBO
+            updates: PyTree = vari.elbo_grad(var_draws, key, data)
 
-
-
+            # Compute updates
+            updates, opt_state = optim.update(
+                updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+            )
 
-
+            # Update variational distribution
+            dyn = eqx.apply_updates(dyn, updates)
 
-
-        dyn = lax.while_loop(
-            cond_fun=condition,
-            body_fun=body,
-            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
-        )[0]
+            return dyn, opt_state, i, key
 
-
-
+        # Run optimization loop
+        dyn = lax.while_loop(
+            cond_fun=condition,
+            body_fun=body,
+            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
+        )[0]
 
-
+        # Return optimized variational
+        return eqx.combine(dyn, static)
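The new `eval_model` stacks `eqx.filter_jit` on top of a `jax.vmap` whose `in_axes=(None, 0, None)` maps over the batch of draws while broadcasting `self` and `data`. A self-contained toy sketch of that decorator pattern (the `Quadratic` module is illustrative only, not part of bayinx):

from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array


class Quadratic(eqx.Module):
    """Toy stand-in for a Variational family."""

    center: Array

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0, None))
    def eval_model(self, draw: Array, scale: float) -> Array:
        # in_axes=(None, 0, None) maps over axis 0 of `draw` while treating
        # `self` and the auxiliary argument as unmapped (broadcast) inputs,
        # mirroring Variational.eval_model's (self, draws, data) signature.
        return -0.5 * scale * jnp.sum((draw - self.center) ** 2)


q = Quadratic(center=jnp.zeros(3))
draws = jnp.ones((5, 3))          # a batch of five "variational draws"
print(q.eval_model(draws, 1.0))   # shape (5,): one value per draw

This is how a matrix of variational draws becomes a vector of posterior densities in the new code: each row is unflattened into a model, combined with the stored constraints, and evaluated, with the vmap handling the batching.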
{bayinx-0.2.22.dist-info → bayinx-0.2.24.dist-info}/RECORD
CHANGED
@@ -1,10 +1,11 @@
 bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
 bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
+bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
 bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
-bayinx/core/model.py,sha256
+bayinx/core/model.py,sha256=t7s5Yt4E3iC_MPvynJnk6lb4OLal7Gnk59tIZ6e5M4I,2203
 bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
-bayinx/core/variational.py,sha256=
+bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
 bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
 bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -15,13 +16,13 @@ bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
 bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
 bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
-bayinx/mhx/vi/normalizing_flow.py,sha256=
+bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
 bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
 bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
 bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
-bayinx/mhx/vi/flows/planar.py,sha256=
+bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
 bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
 bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
-bayinx-0.2.
-bayinx-0.2.
-bayinx-0.2.
+bayinx-0.2.24.dist-info/METADATA,sha256=sR0C0Pk5vrAmdvAtB3faXZO-hIDpKzqLjnXcfMsikjw,3058
+bayinx-0.2.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+bayinx-0.2.24.dist-info/RECORD,,
{bayinx-0.2.22.dist-info → bayinx-0.2.24.dist-info}/WHEEL
File without changes