bayinx 0.2.10__tar.gz → 0.2.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {bayinx-0.2.10 → bayinx-0.2.11}/PKG-INFO +1 -1
  2. {bayinx-0.2.10 → bayinx-0.2.11}/pyproject.toml +1 -1
  3. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/__init__.py +1 -1
  4. bayinx-0.2.10/src/bayinx/mhx/vi/flows/affine.py → bayinx-0.2.11/src/bayinx/mhx/vi/flows/fullaffine.py +1 -1
  5. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/planar.py +16 -2
  6. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/normalizing_flow.py +1 -1
  7. {bayinx-0.2.10 → bayinx-0.2.11}/tests/test_variational.py +22 -15
  8. {bayinx-0.2.10 → bayinx-0.2.11}/.github/workflows/release_and_publish.yml +0 -0
  9. {bayinx-0.2.10 → bayinx-0.2.11}/.gitignore +0 -0
  10. {bayinx-0.2.10 → bayinx-0.2.11}/README.md +0 -0
  11. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/__init__.py +0 -0
  12. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/__init__.py +0 -0
  13. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/flow.py +0 -0
  14. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/model.py +0 -0
  15. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/utils.py +0 -0
  16. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/variational.py +0 -0
  17. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/__init__.py +0 -0
  18. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/bernoulli.py +0 -0
  19. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/binomial.py +0 -0
  20. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/gamma.py +0 -0
  21. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/gamma2.py +0 -0
  22. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/normal.py +0 -0
  23. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/__init__.py +0 -0
  24. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/__init__.py +0 -0
  25. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  26. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  27. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/meanfield.py +0 -0
  28. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/standard.py +0 -0
  29. {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/py.typed +0 -0
  30. {bayinx-0.2.10 → bayinx-0.2.11}/tests/__init__.py +0 -0
  31. {bayinx-0.2.10 → bayinx-0.2.11}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.10
3
+ Version: 0.2.11
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.2.10"
3
+ version = "0.2.11"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -1,3 +1,3 @@
1
- from bayinx.mhx.vi.flows.affine import Affine as Affine
1
+ from bayinx.mhx.vi.flows.fullaffine import FullAffine as FullAffine
2
2
  from bayinx.mhx.vi.flows.planar import Planar as Planar
3
3
  from bayinx.mhx.vi.flows.radial import Radial as Radial
@@ -9,7 +9,7 @@ from jaxtyping import Array, Float, Scalar
9
9
  from bayinx.core import Flow
10
10
 
11
11
 
12
- class Affine(Flow):
12
+ class FullAffine(Flow):
13
13
  """
14
14
  An affine flow.
15
15
 
@@ -3,6 +3,7 @@ from typing import Callable, Dict, Tuple
3
3
 
4
4
  import equinox as eqx
5
5
  import jax
6
+ import jax.nn as jnn
6
7
  import jax.numpy as jnp
7
8
  import jax.random as jr
8
9
  from jaxtyping import Array, Float, Scalar
@@ -36,10 +37,23 @@ class Planar(Flow):
36
37
  }
37
38
  self.constraints = {}
38
39
 
40
+ def transform_pars(self):
41
+ params = self.params
42
+
43
+ u = params['u']
44
+ w = params['w']
45
+ b = params['b']
46
+
47
+ m = jnn.softplus(w.dot(u)) - 1.0
48
+
49
+ u = u + (m - w.dot(u)) * w / (w**2).sum()
50
+
51
+ return {'u': u, 'w': w, 'b': b}
52
+
39
53
  @eqx.filter_jit
40
54
  @partial(jax.vmap, in_axes=(None, 0))
41
55
  def forward(self, draws: Array) -> Array:
42
- params = self.constrain_pars()
56
+ params = self.transform_pars()
43
57
 
44
58
  # Extract parameters
45
59
  w: Array = params["w"]
@@ -54,7 +68,7 @@ class Planar(Flow):
54
68
  @eqx.filter_jit
55
69
  @partial(jax.vmap, in_axes=(None, 0))
56
70
  def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
57
- params = self.constrain_pars()
71
+ params = self.transform_pars()
58
72
 
59
73
  # Extract parameters
60
74
  w: Array = params["w"]
@@ -115,7 +115,7 @@ class NormalizingFlow(Variational):
115
115
  return filter_spec
116
116
 
117
117
  @eqx.filter_jit
118
- def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
118
+ def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
119
119
  dyn, static = eqx.partition(self, self.filter_spec())
120
120
 
121
121
  @eqx.filter_jit
@@ -8,7 +8,7 @@ from jaxtyping import Array
8
8
  from bayinx import Model
9
9
  from bayinx.dists import normal
10
10
  from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
11
- from bayinx.mhx.vi.flows import Affine, Planar, Radial
11
+ from bayinx.mhx.vi.flows import FullAffine, Planar
12
12
 
13
13
 
14
14
  # Tests ----
@@ -77,7 +77,7 @@ def test_affine(benchmark, var_draws):
77
77
  model = NormalDist()
78
78
 
79
79
  # Construct normalizing flow variational
80
- vari = NormalizingFlow(Standard(model), [Affine(2)], model)
80
+ vari = NormalizingFlow(Standard(model), [FullAffine(2)], model)
81
81
 
82
82
  # Optimize variational distribution
83
83
  def benchmark_fit():
@@ -95,30 +95,37 @@ def test_affine(benchmark, var_draws):
95
95
  @pytest.mark.parametrize("var_draws", [1, 10, 100])
96
96
  def test_flows(benchmark, var_draws):
97
97
  # Construct model definition
98
- class NormalDist(Model):
98
+ class Banana(Model):
99
99
  params: Dict[str, Array]
100
100
  constraints: Dict[str, Callable[[Array], Array]]
101
101
 
102
102
  def __init__(self):
103
- self.params = {"mu": jnp.array([0.0, 0.0])}
103
+ self.params = {
104
+ 'x': jnp.array(0.0),
105
+ 'y': jnp.array(0.0)
106
+ }
104
107
  self.constraints = {}
105
108
 
106
- @eqx.filter_jit
107
- def eval(self, data: dict):
108
- # Get constrained parameters
109
- params = self.constrain_pars()
109
+ def eval(self, data = None):
110
+ params: Dict[str, Array] = self.params
111
+ # Extract parameters
112
+ x: Array = params['x']
113
+ y: Array = params['y']
110
114
 
111
- # Evaluate mu ~ N(10,1)
112
- return jnp.sum(
113
- normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
114
- )
115
+ # Initialize target density
116
+ target = jnp.array(0.0)
117
+
118
+ target += normal.logprob(x, mu = jnp.array(0.0), sigma = jnp.array(1.0))
119
+ target += normal.logprob(y, mu = x**2 + x, sigma = jnp.array(1.0))
120
+
121
+ return target
115
122
 
116
123
  # Construct model
117
- model = NormalDist()
124
+ model = Banana()
118
125
 
119
126
  # Construct normalizing flow variational
120
127
  vari = NormalizingFlow(
121
- Standard(model), [Planar(2), Radial(2), Planar(2), Radial(2), Planar(2)], model
128
+ Standard(model), [FullAffine(2), Planar(2)], model
122
129
  )
123
130
 
124
131
  # Optimize variational distribution
@@ -126,7 +133,7 @@ def test_flows(benchmark, var_draws):
126
133
  vari.fit(10000, var_draws=var_draws)
127
134
 
128
135
  benchmark(benchmark_fit)
129
- vari = vari.fit(10000)
136
+ vari = vari.fit(100)
130
137
 
131
138
  mean = vari.sample(1000).mean(0)
132
139
  var = vari.sample(1000).var(0)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes