bayinx 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -63,6 +63,16 @@ class Variational(eqx.Module, Generic[M]):
63
63
  """
64
64
  pass
65
65
 
66
+ @eqx.filter_jit
67
+ def reconstruct_model(self, draw: Array) -> M:
68
+ # Unflatten variational draw
69
+ model: M = self._unflatten(draw)
70
+
71
+ # Combine with constraints
72
+ model: M = eqx.combine(model, self._constraints)
73
+
74
+ return model
75
+
66
76
  @eqx.filter_jit
67
77
  @partial(jax.vmap, in_axes=(None, 0, None))
68
78
  def eval_model(self, draws: Array, data: Any = None) -> Array:
@@ -74,10 +84,7 @@ class Variational(eqx.Module, Generic[M]):
74
84
  - `data`: Data used to evaluate the posterior(if needed).
75
85
  """
76
86
  # Unflatten variational draw
77
- model: M = self._unflatten(draws)
78
-
79
- # Combine with constraints
80
- model: M = eqx.combine(model, self._constraints)
87
+ model: M = self.reconstruct_model(draws)
81
88
 
82
89
  # Evaluate posterior density
83
90
  return model.eval(data)
@@ -176,8 +183,8 @@ class Variational(eqx.Module, Generic[M]):
176
183
  key: Key = jr.PRNGKey(0),
177
184
  ) -> Array:
178
185
  # Sample a single draw to evaluate shape of output
179
- draw: Array = self.sample(1, key)
180
- output: Array = func(self._unflatten(draw), data)
186
+ draw: Array = self.sample(1, key)[0]
187
+ output: Array = func(self.reconstruct_model(draw), data)
181
188
 
182
189
  # Allocate space for results
183
190
  results: Array = jnp.zeros((n,) + output.shape, dtype=output.dtype)
@@ -194,7 +201,7 @@ class Variational(eqx.Module, Generic[M]):
194
201
  draw: Array = self.sample(1, key)
195
202
 
196
203
  # Reconstruct model
197
- model: M = self._unflatten(draw)
204
+ model: M = self.reconstruct_model(draw)
198
205
 
199
206
  # Update results with output
200
207
  results = results.at[i].set(func(model, data))
@@ -1,11 +1,11 @@
1
- from typing import Any, Dict, Generic, Self, TypeVar
1
+ from typing import Any, Generic, Self, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
5
5
  import jax.random as jr
6
6
  import jax.tree_util as jtu
7
7
  from jax.flatten_util import ravel_pytree
8
- from jaxtyping import Array, Float, Key, Scalar
8
+ from jaxtyping import Array, Key, Scalar
9
9
 
10
10
  from bayinx.core import Model, Variational
11
11
  from bayinx.dists import normal
@@ -18,12 +18,14 @@ class MeanField(Variational, Generic[M]):
18
18
  A fully factorized Gaussian approximation to a posterior distribution.
19
19
 
20
20
  # Attributes
21
- - `var_params`: The variational parameters for the approximation.
21
+ - `mean`: The mean of the unconstrained approximation.
22
+ - `log_std` The log-transformed standard deviation of the unconstrained approximation.
22
23
  """
23
24
 
24
- var_params: Dict[str, Float[Array, "..."]] # todo: just expand to attributes
25
+ mean: Array
26
+ log_std: Array
25
27
 
26
- def __init__(self, model: M, init_log_std: float = 0.0):
28
+ def __init__(self, model: M, init_log_std: float = -5.0):
27
29
  """
28
30
  Constructs an unoptimized meanfield posterior approximation.
29
31
 
@@ -38,10 +40,8 @@ class MeanField(Variational, Generic[M]):
38
40
  params, self._unflatten = ravel_pytree(params)
39
41
 
40
42
  # Initialize variational parameters
41
- self.var_params = {
42
- "mean": params,
43
- "log_std": jnp.full(params.size, init_log_std, params.dtype),
44
- }
43
+ self.mean = params
44
+ self.log_std = jnp.full(params.size, init_log_std, params.dtype)
45
45
 
46
46
  @property
47
47
  @eqx.filter_jit
@@ -51,7 +51,12 @@ class MeanField(Variational, Generic[M]):
51
51
 
52
52
  # Specify variational parameters
53
53
  filter_spec = eqx.tree_at(
54
- lambda mf: mf.var_params,
54
+ lambda mf: mf.mean,
55
+ filter_spec,
56
+ replace=True,
57
+ )
58
+ filter_spec = eqx.tree_at(
59
+ lambda mf: mf.log_std,
55
60
  filter_spec,
56
61
  replace=True,
57
62
  )
@@ -62,9 +67,9 @@ class MeanField(Variational, Generic[M]):
62
67
  def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
63
68
  # Sample variational draws
64
69
  draws: Array = (
65
- jr.normal(key=key, shape=(n, self.var_params["mean"].size))
66
- * jnp.exp(self.var_params["log_std"])
67
- + self.var_params["mean"]
70
+ jr.normal(key=key, shape=(n, self.mean.size))
71
+ * jnp.exp(self.log_std)
72
+ + self.mean
68
73
  )
69
74
 
70
75
  return draws
@@ -73,8 +78,8 @@ class MeanField(Variational, Generic[M]):
73
78
  def eval(self, draws: Array) -> Array:
74
79
  return normal.logprob(
75
80
  x=draws,
76
- mu=self.var_params["mean"],
77
- sigma=jnp.exp(self.var_params["log_std"]),
81
+ mu=self.mean,
82
+ sigma=jnp.exp(self.log_std),
78
83
  ).sum(axis=1)
79
84
 
80
85
  @eqx.filter_jit
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.16
3
+ Version: 0.3.18
4
4
  Summary: Bayesian Inference with JAX
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -8,7 +8,7 @@ bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
8
  bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
9
9
  bayinx/core/_optimization.py,sha256=mmeVUqfFARz8F7q4LRl-uEwVWzekmzh-9o7PnuvsHZk,2651
10
10
  bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
11
- bayinx/core/_variational.py,sha256=MbOgWs0-mkEwIYsh1Hhn7anWkQ7fAGjkr6VJybYAuR0,6185
11
+ bayinx/core/_variational.py,sha256=NlCvBHsfP-7HeNimhD4x9JLT5VWeiCBUADr3bx6tt84,6381
12
12
  bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
13
13
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
14
14
  bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
@@ -23,7 +23,7 @@ bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-
23
23
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
24
24
  bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
26
- bayinx/mhx/vi/meanfield.py,sha256=wM0v6Q2m0pPyEdOaT8DvFqdRLsHijr7AYNQKAcsZtRQ,3881
26
+ bayinx/mhx/vi/meanfield.py,sha256=iX4AeDG9jrLZd6d9NimuJ3O5zaoBXsD03JbgPgxVrfY,3917
27
27
  bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
28
28
  bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
29
29
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
@@ -31,7 +31,7 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
31
31
  bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
32
32
  bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
33
33
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
34
- bayinx-0.3.16.dist-info/METADATA,sha256=9oLNZAc1YvN9FpWiuFFlztL-OkhclOYPBQFkuiHmUT8,3087
35
- bayinx-0.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
- bayinx-0.3.16.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
37
- bayinx-0.3.16.dist-info/RECORD,,
34
+ bayinx-0.3.18.dist-info/METADATA,sha256=osJ-IhnQs0OuZoA6AGk0GKSwibISSNaOeyLg5F0T4w4,3087
35
+ bayinx-0.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
+ bayinx-0.3.18.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
37
+ bayinx-0.3.18.dist-info/RECORD,,