bayinx 0.3.17__tar.gz → 0.3.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {bayinx-0.3.17 → bayinx-0.3.19}/PKG-INFO +1 -1
  2. {bayinx-0.3.17 → bayinx-0.3.19}/pyproject.toml +2 -2
  3. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/_variational.py +14 -7
  4. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/meanfield.py +20 -15
  5. {bayinx-0.3.17 → bayinx-0.3.19}/tests/test_variational.py +2 -2
  6. {bayinx-0.3.17 → bayinx-0.3.19}/.github/workflows/release_and_publish.yml +0 -0
  7. {bayinx-0.3.17 → bayinx-0.3.19}/.gitignore +0 -0
  8. {bayinx-0.3.17 → bayinx-0.3.19}/LICENSE +0 -0
  9. {bayinx-0.3.17 → bayinx-0.3.19}/README.md +0 -0
  10. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/__init__.py +0 -0
  11. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/constraints/__init__.py +0 -0
  12. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/constraints/lower.py +0 -0
  13. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/__init__.py +0 -0
  14. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/_constraint.py +0 -0
  15. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/_flow.py +0 -0
  16. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/_model.py +0 -0
  17. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/_optimization.py +0 -0
  18. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/core/_parameter.py +0 -0
  19. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/__init__.py +0 -0
  20. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/bernoulli.py +0 -0
  21. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/censored/__init__.py +0 -0
  22. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
  23. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  24. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
  25. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/censored/posnormal/r.py +0 -0
  26. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/gamma2.py +0 -0
  27. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/normal.py +0 -0
  28. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/posnormal.py +0 -0
  29. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/dists/uniform.py +0 -0
  30. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/__init__.py +0 -0
  31. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/opt/__init__.py +0 -0
  32. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/__init__.py +0 -0
  33. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  34. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  35. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  36. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  37. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  38. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
  39. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/mhx/vi/standard.py +0 -0
  40. {bayinx-0.3.17 → bayinx-0.3.19}/src/bayinx/py.typed +0 -0
  41. {bayinx-0.3.17 → bayinx-0.3.19}/tests/__init__.py +0 -0
  42. {bayinx-0.3.17 → bayinx-0.3.19}/tests/test_predictive.py +0 -0
  43. {bayinx-0.3.17 → bayinx-0.3.19}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.17
3
+ Version: 0.3.19
4
4
  Summary: Bayesian Inference with JAX
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.3.17"
3
+ version = "0.3.19"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
19
19
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
20
20
 
21
21
  [tool.bumpversion]
22
- current_version = "0.3.17"
22
+ current_version = "0.3.19"
23
23
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
24
24
  serialize = ["{major}.{minor}.{patch}"]
25
25
  search = "{current_version}"
@@ -63,6 +63,16 @@ class Variational(eqx.Module, Generic[M]):
63
63
  """
64
64
  pass
65
65
 
66
+ @eqx.filter_jit
67
+ def reconstruct_model(self, draw: Array) -> M:
68
+ # Unflatten variational draw
69
+ model: M = self._unflatten(draw)
70
+
71
+ # Combine with constraints
72
+ model: M = eqx.combine(model, self._constraints)
73
+
74
+ return model
75
+
66
76
  @eqx.filter_jit
67
77
  @partial(jax.vmap, in_axes=(None, 0, None))
68
78
  def eval_model(self, draws: Array, data: Any = None) -> Array:
@@ -74,10 +84,7 @@ class Variational(eqx.Module, Generic[M]):
74
84
  - `data`: Data used to evaluate the posterior(if needed).
75
85
  """
76
86
  # Unflatten variational draw
77
- model: M = self._unflatten(draws)
78
-
79
- # Combine with constraints
80
- model: M = eqx.combine(model, self._constraints)
87
+ model: M = self.reconstruct_model(draws)
81
88
 
82
89
  # Evaluate posterior density
83
90
  return model.eval(data)
@@ -177,7 +184,7 @@ class Variational(eqx.Module, Generic[M]):
177
184
  ) -> Array:
178
185
  # Sample a single draw to evaluate shape of output
179
186
  draw: Array = self.sample(1, key)[0]
180
- output: Array = func(self._unflatten(draw), data)
187
+ output: Array = func(self.reconstruct_model(draw), data)
181
188
 
182
189
  # Allocate space for results
183
190
  results: Array = jnp.zeros((n,) + output.shape, dtype=output.dtype)
@@ -191,10 +198,10 @@ class Variational(eqx.Module, Generic[M]):
191
198
  next, key = jr.split(key)
192
199
 
193
200
  # Draw from variational
194
- draw: Array = self.sample(1, key)
201
+ draw: Array = self.sample(1, key)[0]
195
202
 
196
203
  # Reconstruct model
197
- model: M = self._unflatten(draw)
204
+ model: M = self.reconstruct_model(draw)
198
205
 
199
206
  # Update results with output
200
207
  results = results.at[i].set(func(model, data))
@@ -1,11 +1,11 @@
1
- from typing import Any, Dict, Generic, Self, TypeVar
1
+ from typing import Any, Generic, Self, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
5
5
  import jax.random as jr
6
6
  import jax.tree_util as jtu
7
7
  from jax.flatten_util import ravel_pytree
8
- from jaxtyping import Array, Float, Key, Scalar
8
+ from jaxtyping import Array, Key, Scalar
9
9
 
10
10
  from bayinx.core import Model, Variational
11
11
  from bayinx.dists import normal
@@ -18,12 +18,14 @@ class MeanField(Variational, Generic[M]):
18
18
  A fully factorized Gaussian approximation to a posterior distribution.
19
19
 
20
20
  # Attributes
21
- - `var_params`: The variational parameters for the approximation.
21
+ - `mean`: The mean of the unconstrained approximation.
22
+ - `log_std` The log-transformed standard deviation of the unconstrained approximation.
22
23
  """
23
24
 
24
- var_params: Dict[str, Float[Array, "..."]] # todo: just expand to attributes
25
+ mean: Array
26
+ log_std: Array
25
27
 
26
- def __init__(self, model: M, init_log_std: float = 0.0):
28
+ def __init__(self, model: M, init_log_std: float = -5.0):
27
29
  """
28
30
  Constructs an unoptimized meanfield posterior approximation.
29
31
 
@@ -38,10 +40,8 @@ class MeanField(Variational, Generic[M]):
38
40
  params, self._unflatten = ravel_pytree(params)
39
41
 
40
42
  # Initialize variational parameters
41
- self.var_params = {
42
- "mean": params,
43
- "log_std": jnp.full(params.size, init_log_std, params.dtype),
44
- }
43
+ self.mean = params
44
+ self.log_std = jnp.full(params.size, init_log_std, params.dtype)
45
45
 
46
46
  @property
47
47
  @eqx.filter_jit
@@ -51,7 +51,12 @@ class MeanField(Variational, Generic[M]):
51
51
 
52
52
  # Specify variational parameters
53
53
  filter_spec = eqx.tree_at(
54
- lambda mf: mf.var_params,
54
+ lambda mf: mf.mean,
55
+ filter_spec,
56
+ replace=True,
57
+ )
58
+ filter_spec = eqx.tree_at(
59
+ lambda mf: mf.log_std,
55
60
  filter_spec,
56
61
  replace=True,
57
62
  )
@@ -62,9 +67,9 @@ class MeanField(Variational, Generic[M]):
62
67
  def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
63
68
  # Sample variational draws
64
69
  draws: Array = (
65
- jr.normal(key=key, shape=(n, self.var_params["mean"].size))
66
- * jnp.exp(self.var_params["log_std"])
67
- + self.var_params["mean"]
70
+ jr.normal(key=key, shape=(n, self.mean.size))
71
+ * jnp.exp(self.log_std)
72
+ + self.mean
68
73
  )
69
74
 
70
75
  return draws
@@ -73,8 +78,8 @@ class MeanField(Variational, Generic[M]):
73
78
  def eval(self, draws: Array) -> Array:
74
79
  return normal.logprob(
75
80
  x=draws,
76
- mu=self.var_params["mean"],
77
- sigma=jnp.exp(self.var_params["log_std"]),
81
+ mu=self.mean,
82
+ sigma=jnp.exp(self.log_std),
78
83
  ).sum(axis=1)
79
84
 
80
85
  @eqx.filter_jit
@@ -43,8 +43,8 @@ def test_meanfield(benchmark, var_draws):
43
43
  benchmark(benchmark_fit)
44
44
 
45
45
  # Assert parameters are roughly correct
46
- assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
47
- abs(0.0 - vari.var_params["log_std"]) < 0.1
46
+ assert all(abs(10.0 - vari.mean) < 0.1) and all(
47
+ abs(0.0 - vari.log_std) < 0.1
48
48
  )
49
49
 
50
50
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes