bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from typing import Callable, Optional, Self, Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax.flatten_util as jfu
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
import jax.tree_util as jtu
|
|
8
|
+
from jax.lax import scan
|
|
9
|
+
from jaxtyping import Array, PRNGKeyArray, Scalar
|
|
10
|
+
|
|
11
|
+
from bayinx.core.flow import FlowLayer
|
|
12
|
+
from bayinx.core.variational import M, Variational
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NormalizingFlow(Variational[M]):
    """
    An ordered collection of diffeomorphisms that map a base distribution to a
    normalized approximation of a posterior distribution.

    # Attributes
    - `dim`: The dimension of the support.
    - `base`: A base variational distribution.
    - `flows`: An ordered collection of continuously parameterized diffeomorphisms.
    """

    flows: list[FlowLayer]
    base: Variational[M]

    def __init__(
        self,
        base: Variational[M],
        flows: list[FlowLayer],
        model: Optional[M] = None,
        _static: Optional[M] = None,
        _unflatten: Optional[Callable[[Array], M]] = None
    ):
        """
        Constructs an unoptimized normalizing flow posterior approximation.

        # Parameters
        - `base`: The base variational distribution.
        - `flows`: A list of flows, applied in order to draws from `base`.
        - `model`: A probabilistic `Model` object. When given, it takes
          precedence over `_static` and `_unflatten`.
        - `_static`: The static (non-parameter) partition of a model; only used
          together with `_unflatten` when `model` is not given.
        - `_unflatten`: A function mapping a flat parameter array back to the
          model's parameter pytree; only used together with `_static`.

        # Raises
        - `ValueError`: If neither `model` nor both `_static` and `_unflatten`
          are given.
        """
        if model is not None:
            # Partition model into its parameter and static components
            params, self._static = eqx.partition(model, model.filter_spec)

            # Only the unflattening function is kept; the flat vector is discarded
            _, self._unflatten = jfu.ravel_pytree(params)
        elif _static is not None and _unflatten is not None:
            self._static = _static
            self._unflatten = _unflatten
        else:
            raise ValueError("Either 'model' or '_static' and '_unflatten' must be specified.")

        self.dim = base.dim
        self.base = base
        self.flows = flows

    @property
    def filter_spec(self) -> Self:
        """
        Filter specification selecting the trainable leaves of this approximation.

        Every leaf is marked `False` except `flows`, which is replaced by each flow
        layer's own `filter_spec`; the base distribution is therefore not trained.
        """
        # Generate empty specification
        filter_spec: Self = jtu.tree_map(lambda _: False, self)

        # Specify variational parameters based on each flow's filter spec.
        filter_spec: Self = eqx.tree_at(
            lambda vari: vari.flows,
            filter_spec,
            replace=[flow.filter_spec for flow in self.flows],
        )

        return filter_spec

    @eqx.filter_jit
    def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
        """
        Draw `n` samples by pushing base draws through every flow layer in order.

        # Returns
        A 2-D array of transformed draws.
        """
        # Sample from the base distribution
        draws: Array = self.base.sample(n, key)

        assert len(draws.shape) == 2

        # Apply forward transformations
        for flow in self.flows:
            draws = flow.forward(draws)

        assert len(draws.shape) == 2

        return draws

    @eqx.filter_jit
    def eval(self, draws: Array) -> Array:
        """
        Not supported for normalizing flows.

        # Raises
        - `RuntimeError`: Always. Evaluating the variational density at arbitrary
          points would require each flow to have an analytic inverse.
        """
        raise RuntimeError("Evaluating the variational density for a normalizing flow requires an analytic inverse to exist, which many useful flows do not have. Therefore, do not use this method.")

    @eqx.filter_jit
    def __eval(self, draws: Array) -> Tuple[Array, Array]:
        """
        Evaluate the posterior and variational densities together at the
        transformed `draws` to avoid extra compute.

        # Parameters
        - `draws`: Draws from the base variational distribution (2-D array).

        # Returns
        A tuple `(posterior_evals, variational_evals)` of 1-D arrays, one entry
        per draw. `variational_evals` starts from the base density and has each
        layer's adjustment (`ljas`, presumably the log-absolute-Jacobian returned
        by `forward_and_adjust`) subtracted.
        """
        # Evaluate base density
        variational_evals: Array = self.base.eval(draws)

        # Shape checks
        assert len(variational_evals.shape) == 1
        assert len(draws.shape) == 2

        for flow in self.flows:
            # Apply transformation
            draws, ljas = flow.forward_and_adjust(draws)
            assert len(draws.shape) == 2
            assert len(ljas.shape) == 1

            # Adjust variational density
            variational_evals = variational_evals - ljas

        # Evaluate posterior at final variational draws
        posterior_evals = self.eval_model(draws)

        # Shape checks
        assert len(posterior_evals.shape) == 1
        assert len(variational_evals.shape) == 1
        assert posterior_evals.shape == variational_evals.shape

        return posterior_evals, variational_evals

    def _elbo_estimate(self, n: int, batch_size: int, key: PRNGKeyArray) -> Scalar:
        """
        Monte Carlo ELBO estimate computed over `n // batch_size` batches.

        Each batch draws `batch_size` samples from the base distribution. Any
        remainder of `n` that does not fill a whole batch is not drawn.

        # Raises
        - `ValueError`: If `batch_size < 1` or `n < batch_size`. In that case no
          batch would be evaluated, and the mean over zero batches would be NaN.
        """
        if batch_size < 1 or n < batch_size:
            raise ValueError(
                f"'n' ({n}) must be at least 'batch_size' ({batch_size}) and "
                "'batch_size' must be positive so that at least one batch of "
                "draws is used to estimate the ELBO."
            )
        n_batches = n // batch_size

        # One key per batch
        keys = jr.split(key, n_batches)

        def batched_elbo(carry: None, batch_key: PRNGKeyArray) -> Tuple[None, Array]:
            # Draw from the base distribution, then evaluate both densities
            draws: Array = self.base.sample(batch_size, batch_key)
            post_evals, vari_evals = self.__eval(draws)
            return None, post_evals - vari_evals

        elbo_evals = scan(batched_elbo, init=None, xs=keys, length=n_batches)[1]

        # Average over every draw of every batch
        return jnp.mean(elbo_evals)

    @eqx.filter_jit
    def elbo(self, n: int, batch_size: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Scalar:
        """
        Estimate the ELBO.

        # Parameters
        - `n`: Total number of draws (rounded down to a multiple of `batch_size`).
        - `batch_size`: Number of draws evaluated per batch.
        - `key`: PRNG key.

        # Returns
        The ELBO estimate averaged over all evaluated draws.
        """
        # No gradient is taken, so the module is used whole.
        return self._elbo_estimate(n, batch_size, key)

    @eqx.filter_jit
    def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
        """
        Gradient of the ELBO estimate with respect to the trainable leaves
        selected by `filter_spec`.
        """
        dyn, static = eqx.partition(self, self.filter_spec)

        def objective(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
            # Rebuild the full module so only `dyn` leaves are differentiated
            return eqx.combine(dyn, static)._elbo_estimate(n, batch_size, key)

        elbo_grad: Callable[[Self, int, PRNGKeyArray], Self] = eqx.filter_grad(objective)

        return elbo_grad(dyn, n, key)

    @eqx.filter_jit
    def elbo_and_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Tuple[Scalar, Self]:
        """
        The ELBO estimate and its gradient with respect to the trainable leaves
        selected by `filter_spec`.
        """
        dyn, static = eqx.partition(self, self.filter_spec)

        def objective(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
            # Rebuild the full module so only `dyn` leaves are differentiated
            return eqx.combine(dyn, static)._elbo_estimate(n, batch_size, key)

        elbo_and_grad: Callable[
            [Self, int, PRNGKeyArray], Tuple[Scalar, Self]
        ] = eqx.filter_value_and_grad(objective)

        return elbo_and_grad(dyn, n, key)
|
bayinx/vi/standard.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Self, Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import jax.random as jr
|
|
6
|
+
import jax.tree_util as jtu
|
|
7
|
+
from jax.flatten_util import ravel_pytree
|
|
8
|
+
from jaxtyping import Array, PRNGKeyArray, PyTree, Scalar
|
|
9
|
+
|
|
10
|
+
from bayinx.core.variational import M, Variational
|
|
11
|
+
from bayinx.dists import normal
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Standard(Variational[M]):
    """
    A standard normal approximation to a posterior distribution.

    Draws come from N(0, I) in the flattened parameter space of the model. The
    approximation has no variational parameters, so its `filter_spec` selects
    nothing and the gradient methods raise.

    # Attributes
    - `dim`: The dimension of the support.
    """

    def __init__(self, model: M):
        """
        Constructs a standard normal approximation to a posterior distribution.

        # Parameters
        - `model`: A probabilistic `Model` object.
        """
        # Partition model
        params, self._static = eqx.partition(model, model.filter_spec)

        # Flatten params component
        params, self._unflatten = ravel_pytree(params)

        # Store dimension of parameter space
        self.dim = jnp.size(params)

    @eqx.filter_jit
    def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
        """Draw `n` samples of shape `(n, dim)` from a standard normal."""
        # Sample variational draws
        draws: Array = jr.normal(key=key, shape=(n, self.dim))

        # Shape checks
        assert len(draws.shape) == 2

        return draws

    @eqx.filter_jit
    def eval(self, draws: Array) -> Array:
        """Standard normal log density of each row of `draws` (summed over axis 1)."""
        return normal.logprob(
            x=draws,
            mu=jnp.array(0.0),
            sigma=jnp.array(1.0),
        ).sum(axis=1)

    @property
    def filter_spec(self):
        """Filter specification with every leaf `False`: nothing is trainable."""
        filter_spec = jtu.tree_map(lambda _: False, self)

        return filter_spec

    @eqx.filter_jit
    def elbo(self, n: int, batch_size: int, key: PRNGKeyArray) -> Scalar:
        """
        Monte Carlo estimate of the ELBO using `n` draws.

        `batch_size` is accepted for interface compatibility but is not used:
        all `n` draws are evaluated at once.
        """
        # Sample draws from variational distribution
        draws: Array = self.sample(n, key)

        # Evaluate posterior density for each draw
        posterior_evals: Array = self.eval_model(draws)

        # Evaluate variational density for each draw
        variational_evals: Array = self.eval(draws)

        # Evaluate ELBO
        return jnp.mean(posterior_evals - variational_evals)

    @eqx.filter_jit
    def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
        """
        Not supported.

        # Raises
        - `RuntimeError`: Always; there are no variational parameters.
        """
        raise RuntimeError("Do not use the 'elbo_grad' method for a Standard variational approximation. It has no variational parameters.")

    def elbo_and_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Tuple[Scalar, PyTree]:
        """
        Evaluate the ELBO and its gradient (not supported).

        # Raises
        - `RuntimeError`: Always; there are no variational parameters.
        """
        raise RuntimeError("Do not use the 'elbo_and_grad' method for a Standard variational approximation. It has no variational parameters.")
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: bayinx
|
|
3
|
+
Version: 0.5.3
|
|
4
|
+
Summary: Bayesian Inference with JAX
|
|
5
|
+
Author: Todd Pocuca
|
|
6
|
+
Maintainer: Todd Pocuca
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Python: >=3.13
|
|
9
|
+
Requires-Dist: diffrax>=0.7.0
|
|
10
|
+
Requires-Dist: equinox>=0.13.2
|
|
11
|
+
Requires-Dist: jax>=0.8.0
|
|
12
|
+
Requires-Dist: jaxtyping>=0.2.36
|
|
13
|
+
Requires-Dist: optax>=0.2.4
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
# Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
|
17
|
+
|
|
18
|
+
Bayinx is an embedded probabilistic programming language in Python, powered by
|
|
19
|
+
[JAX](https://github.com/jax-ml/jax). It is heavily inspired by and aims to have
|
|
20
|
+
feature parity with [Stan](https://mc-stan.org/), but extends the types of
|
|
21
|
+
objects you can work with, and it focuses on normalizing-flow variational
|
|
22
|
+
inference for sampling.
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
## Coming From Stan
|
|
26
|
+
|
|
27
|
+
There are a few differences between the syntax of Bayinx and Stan.
|
|
28
|
+
First, as Bayinx is embedded in Python, model definitions are Pythonic and
|
|
29
|
+
rely on you defining a class that inherits from the `Model` base class:
|
|
30
|
+
|
|
31
|
+
```py
|
|
32
|
+
class MyModel(Model, init=False):
|
|
33
|
+
# ...
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
> Note: Users should specify `init=False` to prevent static type checkers from
|
|
37
|
+
raising irrelevant errors, but more importantly it should remind you that
|
|
38
|
+
you should **NOT** implement your own `__init__` method!
|
|
39
|
+
|
|
40
|
+
The `data` and `parameters` blocks in Stan are then combined into the attribute
|
|
41
|
+
definitions with Bayinx. For example, if we are modelling a simple normal distribution
|
|
42
|
+
with an unknown mean and variance 1, then we might write:
|
|
43
|
+
|
|
44
|
+
```py
|
|
45
|
+
class MyModel(Model, init=False):
|
|
46
|
+
mean: Continuous[Array] = define(shape = ()) # a scalar mean parameter
|
|
47
|
+
x: Observed[Array] = define(shape = 'n_obs') # a vector of observed values
|
|
48
|
+
|
|
49
|
+
# ...
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
The `model` block in Stan is then defined by implementing the `model` method with Bayinx:
|
|
53
|
+
|
|
54
|
+
```py
|
|
55
|
+
class MyModel(Model, init=False):
|
|
56
|
+
mean: Continuous[Array] = define(shape = ())
|
|
57
|
+
x: Observed[Array] = define(shape = 'n_obs')
|
|
58
|
+
|
|
59
|
+
def model(self, target):
|
|
60
|
+
# Equivalent to 'x ~ normal(mean, 1.0)' in Stan
|
|
61
|
+
self.x << Normal(self.mean, 1.0)
|
|
62
|
+
|
|
63
|
+
return target
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
Notice that the `~` operator in Stan has been replaced with `<<`, and to reference nodes of a model you must work with `self`.
|
|
67
|
+
|
|
68
|
+
> Note: Bayinx does not currently have an equivalent of Stan's `transformed data` or `transformed parameters` blocks; however, these are likely to be included in a future release.
|
|
69
|
+
|
|
70
|
+
You can then construct the variational approximation to the posterior:
|
|
71
|
+
|
|
72
|
+
```py
|
|
73
|
+
import bayinx as byx
|
|
74
|
+
from bayinx.flows import DiagAffine
|
|
75
|
+
import jax.numpy as jnp
|
|
76
|
+
|
|
77
|
+
# Fit variational approximation
|
|
78
|
+
posterior = byx.Posterior(MyModel, n_obs = 3, x = jnp.array([-1.0, 0.0, 1.0]))
|
|
79
|
+
posterior.configure(flowspecs = [DiagAffine()])
|
|
80
|
+
posterior.fit()
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
You can then work with this approximation by sampling its nodes:
|
|
84
|
+
|
|
85
|
+
```py
|
|
86
|
+
mean_draws = posterior.sample('mean', 10000)
|
|
87
|
+
print(mean_draws.mean())
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
## Roadmap
|
|
92
|
+
- [ ] Implement OT-Flow: https://arxiv.org/abs/2006.00104
|
|
93
|
+
- [ ]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
bayinx/__init__.py,sha256=hhpQ8JM9kzJLUqeV_72ZbUf9gxTFH-SFkq-GyaYscvI,126
|
|
2
|
+
bayinx/ops.py,sha256=QsKrk2tMOxcYXuUidQyiA1e5KbV5WH80XUvM-PV8wZc,2666
|
|
3
|
+
bayinx/posterior.py,sha256=ab7Ubx3BDTkejaNciWDR-J9GEgbygRZPUgc23rz7YOg,6760
|
|
4
|
+
bayinx/constraints/__init__.py,sha256=E9WFI5xPAuVOFTzaLKgG2uV8k5Pho9w0mlmsMYrkSSI,154
|
|
5
|
+
bayinx/constraints/identity.py,sha256=IMR2WHB_GL89IOgkY7sOOScMyMIJkkGpMWpbUVkfOUY,629
|
|
6
|
+
bayinx/constraints/interval.py,sha256=OeM_aED8pZPdhpyrxOUQjg7IXWjXX2GiJGOoocaU7WI,1811
|
|
7
|
+
bayinx/constraints/lower.py,sha256=xXP-vrpQwnBSUN_1f1qYSSKApodfHfveI0f86h6go_k,1517
|
|
8
|
+
bayinx/constraints/upper.py,sha256=RWppD7SKP6KciQt_Wrd_w59vIykm_vaUtE82E9UEcBs,1529
|
|
9
|
+
bayinx/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
bayinx/core/constraint.py,sha256=w5Roomp-YITFyIcfXpd_P18JYEdYo6yVcNFaYk0MiEE,932
|
|
11
|
+
bayinx/core/context.py,sha256=kSvZFHyveUAkEr-gokZjsG6zp7QqrzcXePFcaTtBFqM,1074
|
|
12
|
+
bayinx/core/distribution.py,sha256=d0dMxzh5NxOv9-IYqV9D-amBcGFLsHsxn8D598IRKEI,990
|
|
13
|
+
bayinx/core/flow.py,sha256=SSGxdLAZLuBUQ8g1D6-QwZPjWuVGTK0fzH6X8CAgFxw,2736
|
|
14
|
+
bayinx/core/model.py,sha256=3FiqNPcLI7QTPhwvH-orYgcKoitozqT0kGxMnoa0dV4,7808
|
|
15
|
+
bayinx/core/node.py,sha256=5nRi3YzKGTUhl6-AOPNkYmYX-xb8DjuOt_XLMVX9k8Q,6134
|
|
16
|
+
bayinx/core/types.py,sha256=of52_tUQurdyfbSzdHjQ0EJUWl8DFEW2Pia9Lr_n3Lk,378
|
|
17
|
+
bayinx/core/utils.py,sha256=_2CxYev5Gu85wMqqqydENmnygvYJh2zB76lBD2-s3y0,3519
|
|
18
|
+
bayinx/core/variational.py,sha256=J7vwGKIulVdaZCoqf7XaRt0Ku7LPzNaqa4xOH_fQ-Nw,4800
|
|
19
|
+
bayinx/dists/__init__.py,sha256=7-nWGyK5obf3lxSlAv5JbzEuQVrRxIC1tU_HseMUtXU,218
|
|
20
|
+
bayinx/dists/bernoulli.py,sha256=vHatAx6FCWCjm-epKu-QqlHjrUN_RXkidGWmzTlurOk,5178
|
|
21
|
+
bayinx/dists/binomial.py,sha256=ap6KN4qW_AbGMZNdz5_8VHtlQYRmIZGlhy9-Z6vulsM,6220
|
|
22
|
+
bayinx/dists/exponential.py,sha256=229Ank-WcJsWhzEY93PvclFWz11VDQLaM1uLEAstf7Y,5705
|
|
23
|
+
bayinx/dists/normal.py,sha256=mnAG2GMjyOdHQ6VAdkDrVJ-d3Rzg2MYqLgRq6-LWwBI,5948
|
|
24
|
+
bayinx/dists/poisson.py,sha256=iBcREtJg0xjXhxKSh0lMTmr7tWAfsO0O3FO-Ft8a6vw,5496
|
|
25
|
+
bayinx/flows/__init__.py,sha256=SozrytzAbeTckrcH_zpiS2bjnhkPsrwZc3nYsUhP4YE,299
|
|
26
|
+
bayinx/flows/diagaffine.py,sha256=cMc2QyC9xipy29FQc7Rzwjo0ga8ajHhxA4hwAIhthSg,3402
|
|
27
|
+
bayinx/flows/fullaffine.py,sha256=b6eXUu72ft_Jf2Kx1wUmdMFGZzn-Bz4J76Hpz7ZKoIw,3594
|
|
28
|
+
bayinx/flows/lowrankaffine.py,sha256=heCWs2iZTTf1BXL6M16yMW8p896Bf15TkMhHv-RaObQ,4769
|
|
29
|
+
bayinx/flows/otflow.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
+
bayinx/flows/planar.py,sha256=NLtDcXrT6sVJP4YW0dtgdkins_iEOmvLg0sYsLKUndU,4280
|
|
31
|
+
bayinx/flows/radial.py,sha256=lkIotasV1VteGgHOQxkN2UXRH-ja4oUkU4sXYB52JdQ,6
|
|
32
|
+
bayinx/flows/sylvester.py,sha256=K4Dc8QaKE0WQDRNCMhNdyuBjPnfhljQYTIFz2KBm-h8,6523
|
|
33
|
+
bayinx/nodes/__init__.py,sha256=HWb4Yi-wSd_Fr7AywsAed8EezJXemKagZFRlqypu_qY,141
|
|
34
|
+
bayinx/nodes/continuous.py,sha256=k5GkM6MO6ag4r9wXY48c1j3Hjc0KSalq4xUY9nIz07M,1817
|
|
35
|
+
bayinx/nodes/observed.py,sha256=9SiOHAeLVYSAaFJFuLgF6yiwk8cQgSreSXs5ZvLHaRE,943
|
|
36
|
+
bayinx/nodes/stochastic.py,sha256=2G89o4vC0I0Pw_2zLStaEHJ5OXUSAbVA0hyenyo0aIs,614
|
|
37
|
+
bayinx/vi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
38
|
+
bayinx/vi/meanfield.py,sha256=N9YeItj9Dk6OpdtBQ4mm-OE4nQXJ-xtLDlbK1t9EP5M,3805
|
|
39
|
+
bayinx/vi/normalizing_flow.py,sha256=LsTZgi6ehDZ4LLuHfnLeaGEo2Z0_CZhRePcoNCXJgyo,8226
|
|
40
|
+
bayinx/vi/standard.py,sha256=r4dydWZMQv5QyCPjBLG3eyU3wczxFUu3pqdA2569eoI,2941
|
|
41
|
+
bayinx-0.5.3.dist-info/METADATA,sha256=ptDkp2X80xOz0pkKajne-pJSllGH2fNHLZzwYYwJt38,2937
|
|
42
|
+
bayinx-0.5.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
43
|
+
bayinx-0.5.3.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
|
44
|
+
bayinx-0.5.3.dist-info/RECORD,,
|
bayinx/core/_constraint.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from typing import Tuple
|
|
3
|
-
|
|
4
|
-
import equinox as eqx
|
|
5
|
-
from jaxtyping import Scalar
|
|
6
|
-
|
|
7
|
-
from bayinx.core._parameter import Parameter
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class Constraint(eqx.Module):
    """
    Base class for parameter constraints.

    Concrete constraints implement `constrain`, mapping an unconstrained
    `Parameter` onto its constrained domain.
    """

    @abstractmethod
    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
        """
        Map `x` onto the constrained domain.

        # Parameters
        - `x`: The unconstrained `Parameter`.

        # Returns
        A pair `(constrained, laj)`: the constrained `Parameter` and a scalar
        Array holding the log-absolute-Jacobian of the transformation.
        """
        ...
|
bayinx/core/_flow.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from typing import Callable, Dict, Self, Tuple
|
|
3
|
-
|
|
4
|
-
import equinox as eqx
|
|
5
|
-
import jax.tree_util as jtu
|
|
6
|
-
from jaxtyping import Array, Float
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class Flow(eqx.Module):
    """
    An abstract base class for a flow (one layer of a normalizing flow).

    # Attributes
    - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
    - `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
    """

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]

    @abstractmethod
    def forward(self, draws: Array) -> Array:
        """
        Computes the forward transformation of `draws`.
        """
        pass

    @abstractmethod
    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
        """
        Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.

        # Returns
        A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
        """
        pass

    # Default filter specification
    @property
    @eqx.filter_jit
    def filter_spec(self):
        """
        Generates a filter specification to subset relevant parameters for the flow.

        Only the array entries of `params` are marked trainable.
        """
        # Generate empty specification
        filter_spec = jtu.tree_map(lambda _: False, self)

        # Specify JAX Array parameters
        filter_spec = eqx.tree_at(
            lambda flow: flow.params,
            filter_spec,
            replace=jtu.tree_map(eqx.is_array, self.params),
        )

        return filter_spec

    @eqx.filter_jit
    def constrain_params(self: Self):
        """
        Constrain `params` to the appropriate domain.

        Each entry named in `constraints` is passed through its function; other
        entries are returned unchanged.

        # Returns
        A new dictionary of JAX Arrays representing the constrained parameters.
        `self.params` itself is never modified.
        """
        # Shallow copy: assigning into `self.params` would mutate the module's
        # own dictionary in place.
        t_params = dict(self.params)

        for name, constraint_fn in self.constraints.items():
            t_params[name] = constraint_fn(t_params[name])

        return t_params

    @eqx.filter_jit
    def transform_params(self: Self) -> Dict[str, Array]:
        """
        Apply a custom transformation to `params` if needed.

        # Returns
        A dictionary of transformed JAX Arrays representing the transformed parameters.
        """
        return self.constrain_params()
|
bayinx/core/_model.py
DELETED
|
@@ -1,98 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from dataclasses import field, fields
|
|
3
|
-
from typing import Any, Self, Tuple
|
|
4
|
-
|
|
5
|
-
import equinox as eqx
|
|
6
|
-
import jax.numpy as jnp
|
|
7
|
-
import jax.tree as jt
|
|
8
|
-
from jaxtyping import Scalar
|
|
9
|
-
|
|
10
|
-
from ._constraint import Constraint
|
|
11
|
-
from ._parameter import Parameter
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def constrain(constraint: Constraint):
    """
    Build a dataclass field that carries `constraint` in its metadata.

    The constraint is stored under the `"constraint"` metadata key.
    """
    metadata = dict(constraint=constraint)
    return field(metadata=metadata)
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class Model(eqx.Module):
    """
    Abstract base class for probabilistic models.

    Parameter attributes are annotated with `Parameter`. A constraint is attached
    to a parameter by assigning `constrain(SomeConstraint)` to it in the class body.
    """

    @abstractmethod
    def eval(self, data: Any) -> Scalar:
        pass

    # Default filter specification
    @property
    @eqx.filter_jit
    def filter_spec(self) -> Self:
        """
        Filter specification that selects only the leaves of `Parameter` attributes.

        Each `Parameter` attribute contributes its own `filter_spec`; everything
        else is `False`.
        """
        spec: Self = jt.map(lambda _: False, self)

        for f in fields(self):
            attr = getattr(self, f.name)
            if not isinstance(attr, Parameter):
                continue

            name = f.name
            spec = eqx.tree_at(
                lambda m: getattr(m, name),
                spec,
                replace=attr.filter_spec,
            )

        return spec

    @eqx.filter_jit
    def constrain_params(self) -> Tuple[Self, Scalar]:
        """
        Apply the field-metadata constraint of every constrained parameter.

        # Returns
        A pair `(model, adjustment)`: a copy of the model with constrained
        parameters, and the summed log-absolute-Jacobians to add to the
        posterior.
        """
        result: Self = self
        adjustment: Scalar = jnp.array(0.0)

        for f in fields(self):
            attr = getattr(self, f.name)
            if not (isinstance(attr, Parameter) and "constraint" in f.metadata):
                continue

            constrained_param, laj = f.metadata["constraint"].constrain(attr)

            name = f.name
            result = eqx.tree_at(
                lambda m: getattr(m, name), result, replace=constrained_param
            )

            adjustment += laj

        return result, adjustment

    @eqx.filter_jit
    def transform_params(self) -> Tuple[Self, Scalar]:
        """
        Hook for a custom parameter transformation; by default, the constraining
        transformation from `constrain_params`.

        # Returns
        A transformed `Model` object and the adjustment to the posterior.
        """
        return self.constrain_params()
|