bayinx 0.3.17__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/_variational.py +13 -6
- bayinx/mhx/vi/meanfield.py +20 -15
- {bayinx-0.3.17.dist-info → bayinx-0.3.18.dist-info}/METADATA +1 -1
- {bayinx-0.3.17.dist-info → bayinx-0.3.18.dist-info}/RECORD +6 -6
- {bayinx-0.3.17.dist-info → bayinx-0.3.18.dist-info}/WHEEL +0 -0
- {bayinx-0.3.17.dist-info → bayinx-0.3.18.dist-info}/licenses/LICENSE +0 -0
bayinx/core/_variational.py
CHANGED
@@ -63,6 +63,16 @@ class Variational(eqx.Module, Generic[M]):
|
|
63
63
|
"""
|
64
64
|
pass
|
65
65
|
|
66
|
+
@eqx.filter_jit
|
67
|
+
def reconstruct_model(self, draw: Array) -> M:
|
68
|
+
# Unflatten variational draw
|
69
|
+
model: M = self._unflatten(draw)
|
70
|
+
|
71
|
+
# Combine with constraints
|
72
|
+
model: M = eqx.combine(model, self._constraints)
|
73
|
+
|
74
|
+
return model
|
75
|
+
|
66
76
|
@eqx.filter_jit
|
67
77
|
@partial(jax.vmap, in_axes=(None, 0, None))
|
68
78
|
def eval_model(self, draws: Array, data: Any = None) -> Array:
|
@@ -74,10 +84,7 @@ class Variational(eqx.Module, Generic[M]):
|
|
74
84
|
- `data`: Data used to evaluate the posterior(if needed).
|
75
85
|
"""
|
76
86
|
# Unflatten variational draw
|
77
|
-
model: M = self.
|
78
|
-
|
79
|
-
# Combine with constraints
|
80
|
-
model: M = eqx.combine(model, self._constraints)
|
87
|
+
model: M = self.reconstruct_model(draws)
|
81
88
|
|
82
89
|
# Evaluate posterior density
|
83
90
|
return model.eval(data)
|
@@ -177,7 +184,7 @@ class Variational(eqx.Module, Generic[M]):
|
|
177
184
|
) -> Array:
|
178
185
|
# Sample a single draw to evaluate shape of output
|
179
186
|
draw: Array = self.sample(1, key)[0]
|
180
|
-
output: Array = func(self.
|
187
|
+
output: Array = func(self.reconstruct_model(draw), data)
|
181
188
|
|
182
189
|
# Allocate space for results
|
183
190
|
results: Array = jnp.zeros((n,) + output.shape, dtype=output.dtype)
|
@@ -194,7 +201,7 @@ class Variational(eqx.Module, Generic[M]):
|
|
194
201
|
draw: Array = self.sample(1, key)
|
195
202
|
|
196
203
|
# Reconstruct model
|
197
|
-
model: M = self.
|
204
|
+
model: M = self.reconstruct_model(draw)
|
198
205
|
|
199
206
|
# Update results with output
|
200
207
|
results = results.at[i].set(func(model, data))
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Generic, Self, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jax.random as jr
|
6
6
|
import jax.tree_util as jtu
|
7
7
|
from jax.flatten_util import ravel_pytree
|
8
|
-
from jaxtyping import Array,
|
8
|
+
from jaxtyping import Array, Key, Scalar
|
9
9
|
|
10
10
|
from bayinx.core import Model, Variational
|
11
11
|
from bayinx.dists import normal
|
@@ -18,12 +18,14 @@ class MeanField(Variational, Generic[M]):
|
|
18
18
|
A fully factorized Gaussian approximation to a posterior distribution.
|
19
19
|
|
20
20
|
# Attributes
|
21
|
-
- `
|
21
|
+
- `mean`: The mean of the unconstrained approximation.
|
22
|
+
- `log_std` The log-transformed standard deviation of the unconstrained approximation.
|
22
23
|
"""
|
23
24
|
|
24
|
-
|
25
|
+
mean: Array
|
26
|
+
log_std: Array
|
25
27
|
|
26
|
-
def __init__(self, model: M, init_log_std: float =
|
28
|
+
def __init__(self, model: M, init_log_std: float = -5.0):
|
27
29
|
"""
|
28
30
|
Constructs an unoptimized meanfield posterior approximation.
|
29
31
|
|
@@ -38,10 +40,8 @@ class MeanField(Variational, Generic[M]):
|
|
38
40
|
params, self._unflatten = ravel_pytree(params)
|
39
41
|
|
40
42
|
# Initialize variational parameters
|
41
|
-
self.
|
42
|
-
|
43
|
-
"log_std": jnp.full(params.size, init_log_std, params.dtype),
|
44
|
-
}
|
43
|
+
self.mean = params
|
44
|
+
self.log_std = jnp.full(params.size, init_log_std, params.dtype)
|
45
45
|
|
46
46
|
@property
|
47
47
|
@eqx.filter_jit
|
@@ -51,7 +51,12 @@ class MeanField(Variational, Generic[M]):
|
|
51
51
|
|
52
52
|
# Specify variational parameters
|
53
53
|
filter_spec = eqx.tree_at(
|
54
|
-
lambda mf: mf.
|
54
|
+
lambda mf: mf.mean,
|
55
|
+
filter_spec,
|
56
|
+
replace=True,
|
57
|
+
)
|
58
|
+
filter_spec = eqx.tree_at(
|
59
|
+
lambda mf: mf.log_std,
|
55
60
|
filter_spec,
|
56
61
|
replace=True,
|
57
62
|
)
|
@@ -62,9 +67,9 @@ class MeanField(Variational, Generic[M]):
|
|
62
67
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
63
68
|
# Sample variational draws
|
64
69
|
draws: Array = (
|
65
|
-
jr.normal(key=key, shape=(n, self.
|
66
|
-
* jnp.exp(self.
|
67
|
-
+ self.
|
70
|
+
jr.normal(key=key, shape=(n, self.mean.size))
|
71
|
+
* jnp.exp(self.log_std)
|
72
|
+
+ self.mean
|
68
73
|
)
|
69
74
|
|
70
75
|
return draws
|
@@ -73,8 +78,8 @@ class MeanField(Variational, Generic[M]):
|
|
73
78
|
def eval(self, draws: Array) -> Array:
|
74
79
|
return normal.logprob(
|
75
80
|
x=draws,
|
76
|
-
mu=self.
|
77
|
-
sigma=jnp.exp(self.
|
81
|
+
mu=self.mean,
|
82
|
+
sigma=jnp.exp(self.log_std),
|
78
83
|
).sum(axis=1)
|
79
84
|
|
80
85
|
@eqx.filter_jit
|
@@ -8,7 +8,7 @@ bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
|
8
8
|
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
9
9
|
bayinx/core/_optimization.py,sha256=mmeVUqfFARz8F7q4LRl-uEwVWzekmzh-9o7PnuvsHZk,2651
|
10
10
|
bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
|
11
|
-
bayinx/core/_variational.py,sha256=
|
11
|
+
bayinx/core/_variational.py,sha256=NlCvBHsfP-7HeNimhD4x9JLT5VWeiCBUADr3bx6tt84,6381
|
12
12
|
bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
|
13
13
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
14
14
|
bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
|
@@ -23,7 +23,7 @@ bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-
|
|
23
23
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
24
24
|
bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
|
26
|
-
bayinx/mhx/vi/meanfield.py,sha256=
|
26
|
+
bayinx/mhx/vi/meanfield.py,sha256=iX4AeDG9jrLZd6d9NimuJ3O5zaoBXsD03JbgPgxVrfY,3917
|
27
27
|
bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
|
28
28
|
bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
|
29
29
|
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
@@ -31,7 +31,7 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
|
|
31
31
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
32
32
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
33
33
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
34
|
-
bayinx-0.3.
|
35
|
-
bayinx-0.3.
|
36
|
-
bayinx-0.3.
|
37
|
-
bayinx-0.3.
|
34
|
+
bayinx-0.3.18.dist-info/METADATA,sha256=osJ-IhnQs0OuZoA6AGk0GKSwibISSNaOeyLg5F0T4w4,3087
|
35
|
+
bayinx-0.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
36
|
+
bayinx-0.3.18.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
37
|
+
bayinx-0.3.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|