bayinx 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/constraints.py +61 -0
- bayinx/core/flow.py +1 -3
- bayinx/core/model.py +28 -28
- bayinx/core/utils.py +0 -53
- bayinx/core/variational.py +84 -89
- bayinx/mhx/vi/normalizing_flow.py +1 -0
- {bayinx-0.2.23.dist-info → bayinx-0.2.25.dist-info}/METADATA +1 -1
- {bayinx-0.2.23.dist-info → bayinx-0.2.25.dist-info}/RECORD +9 -8
- {bayinx-0.2.23.dist-info → bayinx-0.2.25.dist-info}/WHEEL +0 -0
@@ -0,0 +1,61 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Tuple
|
3
|
+
|
4
|
+
import equinox as eqx
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
|
7
|
+
|
8
|
+
|
9
|
+
class Constraint(eqx.Module):
|
10
|
+
"""
|
11
|
+
Abstract base class for defining parameter constraints.
|
12
|
+
|
13
|
+
Subclasses should implement the `constrain` method to apply the
|
14
|
+
transformation and compute the ladj adjustment.
|
15
|
+
"""
|
16
|
+
@abstractmethod
|
17
|
+
def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
|
18
|
+
"""
|
19
|
+
Applies the constraining transformation to an unconstrained input
|
20
|
+
and computes the log absolute determinant of the Jacobian (ladj)
|
21
|
+
of this transformation.
|
22
|
+
|
23
|
+
# Parameters
|
24
|
+
- `x`: The unconstrained JAX Array-like input.
|
25
|
+
|
26
|
+
# Returns
|
27
|
+
A tuple containing:
|
28
|
+
- The constrained JAX Array.
|
29
|
+
- A scalar JAX Array representing the ladj of the transformation.
|
30
|
+
"""
|
31
|
+
pass
|
32
|
+
|
33
|
+
|
34
|
+
class LowerBound(Constraint):
|
35
|
+
"""
|
36
|
+
Enforces a lower bound on the parameter.
|
37
|
+
"""
|
38
|
+
lb: ScalarLike
|
39
|
+
|
40
|
+
def __init__(self, lb: ScalarLike):
|
41
|
+
self.lb = lb
|
42
|
+
|
43
|
+
def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
|
44
|
+
"""
|
45
|
+
Applies the lower bound constraint and computes the ladj.
|
46
|
+
|
47
|
+
# Parameters
|
48
|
+
- `x`: The unconstrained JAX Array-like input.
|
49
|
+
|
50
|
+
# Parameters
|
51
|
+
A tuple containing:
|
52
|
+
- The constrained JAX Array (x > self.lb).
|
53
|
+
- A scalar JAX Array representing the ladj of the transformation.
|
54
|
+
"""
|
55
|
+
# Compute transformation adjustment
|
56
|
+
ladj = jnp.sum(x)
|
57
|
+
|
58
|
+
# Compute transformation
|
59
|
+
x = jnp.exp(x) + self.lb
|
60
|
+
|
61
|
+
return x, ladj
|
bayinx/core/flow.py
CHANGED
@@ -5,10 +5,8 @@ import equinox as eqx
|
|
5
5
|
import jax.tree_util as jtu
|
6
6
|
from jaxtyping import Array, Float
|
7
7
|
|
8
|
-
from bayinx.core.utils import __MyMeta
|
9
8
|
|
10
|
-
|
11
|
-
class Flow(eqx.Module, metaclass=__MyMeta):
|
9
|
+
class Flow(eqx.Module):
|
12
10
|
"""
|
13
11
|
A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
|
14
12
|
|
bayinx/core/model.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Dict, Tuple
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
|
+
import jax.numpy as jnp
|
5
6
|
import jax.tree_util as jtu
|
6
7
|
from jaxtyping import Array, Scalar
|
7
8
|
|
8
|
-
from bayinx.core.
|
9
|
+
from bayinx.core.constraints import Constraint
|
9
10
|
|
10
11
|
|
11
|
-
class Model(eqx.Module
|
12
|
+
class Model(eqx.Module):
|
12
13
|
"""
|
13
14
|
A superclass used to define probabilistic models.
|
14
15
|
|
@@ -18,7 +19,7 @@ class Model(eqx.Module, metaclass=__MyMeta):
|
|
18
19
|
"""
|
19
20
|
|
20
21
|
params: Dict[str, Array]
|
21
|
-
constraints: Dict[str,
|
22
|
+
constraints: Dict[str, Constraint]
|
22
23
|
|
23
24
|
@abstractmethod
|
24
25
|
def eval(self, data: Any) -> Scalar:
|
@@ -41,34 +42,33 @@ class Model(eqx.Module, metaclass=__MyMeta):
|
|
41
42
|
|
42
43
|
return filter_spec
|
43
44
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
# Returns
|
51
|
-
A dictionary of transformed JAX Arrays representing the constrained parameters.
|
52
|
-
"""
|
53
|
-
t_params = self.params
|
45
|
+
# Add constrain method
|
46
|
+
@eqx.filter_jit
|
47
|
+
def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
|
48
|
+
"""
|
49
|
+
Constrain `params` to the appropriate domain.
|
54
50
|
|
55
|
-
|
56
|
-
|
51
|
+
# Returns
|
52
|
+
A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
|
53
|
+
"""
|
54
|
+
t_params: Dict[str, Array] = self.params
|
55
|
+
target: Scalar = jnp.array(0.0)
|
57
56
|
|
58
|
-
|
57
|
+
for par, map in self.constraints.items():
|
58
|
+
# Apply transformation
|
59
|
+
t_params[par], ladj = map.constrain(t_params[par])
|
59
60
|
|
60
|
-
|
61
|
+
# Adjust posterior density
|
62
|
+
target -= ladj
|
61
63
|
|
62
|
-
|
63
|
-
if not callable(getattr(cls, "transform_pars", None)):
|
64
|
+
return t_params, target
|
64
65
|
|
65
|
-
def transform_pars(self: Model) -> Dict[str, Array]:
|
66
|
-
"""
|
67
|
-
Apply a custom transformation to `params` if needed.
|
68
66
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
return self.constrain_pars()
|
67
|
+
def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
|
68
|
+
"""
|
69
|
+
Apply a custom transformation to `params` if needed.
|
73
70
|
|
74
|
-
|
71
|
+
# Returns
|
72
|
+
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
73
|
+
"""
|
74
|
+
return self.constrain_pars()
|
bayinx/core/utils.py
CHANGED
@@ -1,54 +1 @@
|
|
1
|
-
from typing import Callable, Dict
|
2
1
|
|
3
|
-
import equinox as eqx
|
4
|
-
from jaxtyping import Array
|
5
|
-
|
6
|
-
|
7
|
-
class __MyMeta(type(eqx.Module)):
|
8
|
-
"""
|
9
|
-
Metaclass to ensure attribute types are respected.
|
10
|
-
"""
|
11
|
-
|
12
|
-
def __call__(cls, *args, **kwargs):
|
13
|
-
obj = super().__call__(*args, **kwargs)
|
14
|
-
|
15
|
-
# Check parameters are a Dict of JAX Arrays
|
16
|
-
if not isinstance(obj.params, Dict):
|
17
|
-
raise ValueError(
|
18
|
-
f"Model {cls.__name__} must initialize 'params' as a dictionary."
|
19
|
-
)
|
20
|
-
|
21
|
-
for key, value in obj.params.items():
|
22
|
-
if not isinstance(value, Array):
|
23
|
-
raise TypeError(f"Parameter '{key}' must be a JAX Array.")
|
24
|
-
|
25
|
-
# Check constraints are a Dict of functions
|
26
|
-
if not isinstance(obj.constraints, Dict):
|
27
|
-
raise ValueError(
|
28
|
-
f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
|
29
|
-
)
|
30
|
-
|
31
|
-
for key, value in obj.constraints.items():
|
32
|
-
if not isinstance(value, Callable):
|
33
|
-
raise TypeError(f"Constraint for parameter '{key}' must be a function.")
|
34
|
-
|
35
|
-
# Check that the constrain method returns a dict equivalent to `params`
|
36
|
-
t_params: Dict[str, Array] = obj.constrain_pars()
|
37
|
-
|
38
|
-
if not isinstance(t_params, Dict):
|
39
|
-
raise ValueError(
|
40
|
-
f"The 'constrain' method of {cls.__name__} must return a Dict of JAX Arrays."
|
41
|
-
)
|
42
|
-
|
43
|
-
for key, value in t_params.items():
|
44
|
-
if not isinstance(value, Array):
|
45
|
-
raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
|
46
|
-
|
47
|
-
if not value.shape == obj.params[key].shape:
|
48
|
-
raise ValueError(
|
49
|
-
f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
|
50
|
-
)
|
51
|
-
|
52
|
-
# Check transform_pars
|
53
|
-
|
54
|
-
return obj
|
bayinx/core/variational.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
+
from functools import partial
|
2
3
|
from typing import Any, Callable, Self, Tuple
|
3
4
|
|
4
5
|
import equinox as eqx
|
@@ -60,108 +61,102 @@ class Variational(eqx.Module):
|
|
60
61
|
"""
|
61
62
|
pass
|
62
63
|
|
63
|
-
|
64
|
-
|
65
|
-
|
64
|
+
@eqx.filter_jit
|
65
|
+
@partial(jax.vmap, in_axes=(None, 0, None))
|
66
|
+
def eval_model(self, draws: Array, data: Any = None) -> Array:
|
66
67
|
"""
|
68
|
+
Reconstruct models from variational draws and evaluate their posterior density.
|
67
69
|
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
- `var_draws`: Number of variational draws to draw each iteration.
|
106
|
-
- `key`: A PRNG key.
|
107
|
-
"""
|
108
|
-
# Partition variational
|
109
|
-
dyn, static = eqx.partition(self, self.filter_spec())
|
110
|
-
|
111
|
-
# Construct scheduler
|
112
|
-
schedule: Schedule = opx.cosine_decay_schedule(
|
113
|
-
init_value=learning_rate, decay_steps=max_iters
|
114
|
-
)
|
70
|
+
# Parameters
|
71
|
+
- `draws`: A set of variational draws.
|
72
|
+
- `data`: Data used to evaluate the posterior(if needed).
|
73
|
+
"""
|
74
|
+
# Unflatten variational draw
|
75
|
+
model: Model = self._unflatten(draws)
|
76
|
+
|
77
|
+
# Combine with constraints
|
78
|
+
model: Model = eqx.combine(model, self._constraints)
|
79
|
+
|
80
|
+
# Evaluate posterior density
|
81
|
+
return model.eval(data)
|
82
|
+
|
83
|
+
@eqx.filter_jit
|
84
|
+
def fit(
|
85
|
+
self,
|
86
|
+
max_iters: int,
|
87
|
+
data: Any = None,
|
88
|
+
learning_rate: float = 1,
|
89
|
+
weight_decay: float = 1e-4,
|
90
|
+
tolerance: float = 1e-4,
|
91
|
+
var_draws: int = 1,
|
92
|
+
key: Key = jr.PRNGKey(0),
|
93
|
+
) -> Self:
|
94
|
+
"""
|
95
|
+
Optimize the variational distribution.
|
96
|
+
|
97
|
+
# Parameters
|
98
|
+
- `max_iters`: Maximum number of iterations for the optimization loop.
|
99
|
+
- `data`: Data to evaluate the posterior density with(if available).
|
100
|
+
- `learning_rate`: Initial learning rate for optimization.
|
101
|
+
- `tolerance`: Relative tolerance of ELBO decrease for stopping early.
|
102
|
+
- `var_draws`: Number of variational draws to draw each iteration.
|
103
|
+
- `key`: A PRNG key.
|
104
|
+
"""
|
105
|
+
# Partition variational
|
106
|
+
dyn, static = eqx.partition(self, self.filter_spec())
|
115
107
|
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
opt_state: OptState = optim.init(dyn)
|
108
|
+
# Construct scheduler
|
109
|
+
schedule: Schedule = opx.cosine_decay_schedule(
|
110
|
+
init_value=learning_rate, decay_steps=max_iters
|
111
|
+
)
|
121
112
|
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
113
|
+
# Initialize optimizer
|
114
|
+
optim: GradientTransformation = opx.chain(
|
115
|
+
opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
|
116
|
+
)
|
117
|
+
opt_state: OptState = optim.init(dyn)
|
127
118
|
|
128
|
-
|
119
|
+
# Optimization loop helper functions
|
120
|
+
@eqx.filter_jit
|
121
|
+
def condition(state: Tuple[Self, OptState, Scalar, Key]):
|
122
|
+
# Unpack iteration state
|
123
|
+
dyn, opt_state, i, key = state
|
129
124
|
|
130
|
-
|
131
|
-
def body(state: Tuple[Self, OptState, Scalar, Key]):
|
132
|
-
# Unpack iteration state
|
133
|
-
dyn, opt_state, i, key = state
|
125
|
+
return i < max_iters
|
134
126
|
|
135
|
-
|
136
|
-
|
127
|
+
@eqx.filter_jit
|
128
|
+
def body(state: Tuple[Self, OptState, Scalar, Key]):
|
129
|
+
# Unpack iteration state
|
130
|
+
dyn, opt_state, i, key = state
|
137
131
|
|
138
|
-
|
139
|
-
|
132
|
+
# Update iteration
|
133
|
+
i = i + 1
|
140
134
|
|
141
|
-
|
142
|
-
|
135
|
+
# Update PRNG key
|
136
|
+
key, _ = jr.split(key)
|
143
137
|
|
144
|
-
|
145
|
-
|
138
|
+
# Combine variational
|
139
|
+
vari = eqx.combine(dyn, static)
|
146
140
|
|
147
|
-
|
148
|
-
|
149
|
-
updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
|
150
|
-
)
|
141
|
+
# Compute gradient of the ELBO
|
142
|
+
updates: PyTree = vari.elbo_grad(var_draws, key, data)
|
151
143
|
|
152
|
-
|
153
|
-
|
144
|
+
# Compute updates
|
145
|
+
updates, opt_state = optim.update(
|
146
|
+
updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
|
147
|
+
)
|
154
148
|
|
155
|
-
|
149
|
+
# Update variational distribution
|
150
|
+
dyn = eqx.apply_updates(dyn, updates)
|
156
151
|
|
157
|
-
|
158
|
-
dyn = lax.while_loop(
|
159
|
-
cond_fun=condition,
|
160
|
-
body_fun=body,
|
161
|
-
init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
|
162
|
-
)[0]
|
152
|
+
return dyn, opt_state, i, key
|
163
153
|
|
164
|
-
|
165
|
-
|
154
|
+
# Run optimization loop
|
155
|
+
dyn = lax.while_loop(
|
156
|
+
cond_fun=condition,
|
157
|
+
body_fun=body,
|
158
|
+
init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
|
159
|
+
)[0]
|
166
160
|
|
167
|
-
|
161
|
+
# Return optimized variational
|
162
|
+
return eqx.combine(dyn, static)
|
@@ -1,10 +1,11 @@
|
|
1
1
|
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
2
|
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
-
bayinx/core/
|
5
|
-
bayinx/core/
|
6
|
-
bayinx/core/
|
7
|
-
bayinx/core/
|
4
|
+
bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
|
5
|
+
bayinx/core/flow.py,sha256=9swS5wh7AsIZWgG_IhQS-upcPlr7G-juaP_5rsbX6G0,2363
|
6
|
+
bayinx/core/model.py,sha256=U1xBnAXnIvFJjWF-XIT8BYjP9PtoRZY_PwyhRwJf-HA,2144
|
7
|
+
bayinx/core/utils.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
8
|
+
bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
|
8
9
|
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
10
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
11
|
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -15,13 +16,13 @@ bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
|
|
15
16
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
16
17
|
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
17
18
|
bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
|
18
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=
|
19
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
|
19
20
|
bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
|
20
21
|
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
21
22
|
bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
|
22
23
|
bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
|
23
24
|
bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
|
24
25
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
25
|
-
bayinx-0.2.
|
26
|
-
bayinx-0.2.
|
27
|
-
bayinx-0.2.
|
26
|
+
bayinx-0.2.25.dist-info/METADATA,sha256=9TI5NY4M1EBtYwA8E-EDds-QkOECwrfVaaEWq3pTdu4,3058
|
27
|
+
bayinx-0.2.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
28
|
+
bayinx-0.2.25.dist-info/RECORD,,
|
File without changes
|