bayinx 0.2.10__tar.gz → 0.2.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {bayinx-0.2.10 → bayinx-0.2.12}/PKG-INFO +1 -1
  2. {bayinx-0.2.10 → bayinx-0.2.12}/pyproject.toml +2 -2
  3. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/__init__.py +1 -1
  4. bayinx-0.2.10/src/bayinx/mhx/vi/flows/affine.py → bayinx-0.2.12/src/bayinx/mhx/vi/flows/fullaffine.py +1 -1
  5. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/planar.py +5 -5
  6. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/normalizing_flow.py +1 -1
  7. {bayinx-0.2.10 → bayinx-0.2.12}/tests/test_variational.py +6 -6
  8. bayinx-0.2.12/uv.lock +596 -0
  9. bayinx-0.2.10/uv.lock +0 -360
  10. {bayinx-0.2.10 → bayinx-0.2.12}/.github/workflows/release_and_publish.yml +0 -0
  11. {bayinx-0.2.10 → bayinx-0.2.12}/.gitignore +0 -0
  12. {bayinx-0.2.10 → bayinx-0.2.12}/README.md +0 -0
  13. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/__init__.py +0 -0
  14. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/__init__.py +0 -0
  15. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/flow.py +0 -0
  16. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/model.py +0 -0
  17. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/utils.py +0 -0
  18. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/variational.py +0 -0
  19. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/__init__.py +0 -0
  20. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/bernoulli.py +0 -0
  21. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/binomial.py +0 -0
  22. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/gamma.py +0 -0
  23. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/gamma2.py +0 -0
  24. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/normal.py +0 -0
  25. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/__init__.py +0 -0
  26. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/__init__.py +0 -0
  27. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  28. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  29. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/meanfield.py +0 -0
  30. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/standard.py +0 -0
  31. {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/py.typed +0 -0
  32. {bayinx-0.2.10 → bayinx-0.2.12}/tests/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.10
3
+ Version: 0.2.12
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.2.10"
3
+ version = "0.2.12"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -19,4 +19,4 @@ build-backend = "hatchling.build"
19
19
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
20
20
 
21
21
  [dependency-groups]
22
- dev = ["pytest>=8.3.5", "pytest-benchmark>=5.1.0"]
22
+ dev = ["matplotlib>=3.10.1", "pytest>=8.3.5", "pytest-benchmark>=5.1.0"]
@@ -1,3 +1,3 @@
1
- from bayinx.mhx.vi.flows.affine import Affine as Affine
1
+ from bayinx.mhx.vi.flows.fullaffine import FullAffine as FullAffine
2
2
  from bayinx.mhx.vi.flows.planar import Planar as Planar
3
3
  from bayinx.mhx.vi.flows.radial import Radial as Radial
@@ -9,7 +9,7 @@ from jaxtyping import Array, Float, Scalar
9
9
  from bayinx.core import Flow
10
10
 
11
11
 
12
- class Affine(Flow):
12
+ class FullAffine(Flow):
13
13
  """
14
14
  An affine flow.
15
15
 
@@ -30,16 +30,16 @@ class Planar(Flow):
30
30
  - `dim`: The dimension of the parameter space.
31
31
  """
32
32
  self.params = {
33
- "u": jr.normal(key, (dim,)),
34
- "w": jr.normal(key, (dim,)),
35
- "b": jr.normal(key, (1,)),
33
+ "u": jnp.ones(dim),
34
+ "w": jnp.ones(dim),
35
+ "b": jnp.zeros(1),
36
36
  }
37
37
  self.constraints = {}
38
38
 
39
39
  @eqx.filter_jit
40
40
  @partial(jax.vmap, in_axes=(None, 0))
41
41
  def forward(self, draws: Array) -> Array:
42
- params = self.constrain_pars()
42
+ params = self.transform_pars()
43
43
 
44
44
  # Extract parameters
45
45
  w: Array = params["w"]
@@ -54,7 +54,7 @@ class Planar(Flow):
54
54
  @eqx.filter_jit
55
55
  @partial(jax.vmap, in_axes=(None, 0))
56
56
  def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
57
- params = self.constrain_pars()
57
+ params = self.transform_pars()
58
58
 
59
59
  # Extract parameters
60
60
  w: Array = params["w"]
@@ -115,7 +115,7 @@ class NormalizingFlow(Variational):
115
115
  return filter_spec
116
116
 
117
117
  @eqx.filter_jit
118
- def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
118
+ def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
119
119
  dyn, static = eqx.partition(self, self.filter_spec())
120
120
 
121
121
  @eqx.filter_jit
@@ -8,7 +8,7 @@ from jaxtyping import Array
8
8
  from bayinx import Model
9
9
  from bayinx.dists import normal
10
10
  from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
11
- from bayinx.mhx.vi.flows import Affine, Planar, Radial
11
+ from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
12
12
 
13
13
 
14
14
  # Tests ----
@@ -44,7 +44,7 @@ def test_meanfield(benchmark, var_draws):
44
44
  vari.fit(10000, var_draws=var_draws)
45
45
 
46
46
  benchmark(benchmark_fit)
47
- vari = vari.fit(10000)
47
+ vari = vari.fit(20000)
48
48
 
49
49
  # Assert parameters are roughly correct
50
50
  assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
@@ -77,14 +77,14 @@ def test_affine(benchmark, var_draws):
77
77
  model = NormalDist()
78
78
 
79
79
  # Construct normalizing flow variational
80
- vari = NormalizingFlow(Standard(model), [Affine(2)], model)
80
+ vari = NormalizingFlow(Standard(model), [FullAffine(2)], model)
81
81
 
82
82
  # Optimize variational distribution
83
83
  def benchmark_fit():
84
84
  vari.fit(10000, var_draws=var_draws)
85
85
 
86
86
  benchmark(benchmark_fit)
87
- vari = vari.fit(10000)
87
+ vari = vari.fit(20000)
88
88
 
89
89
  params = vari.flows[0].constrain_pars()
90
90
  assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
@@ -118,7 +118,7 @@ def test_flows(benchmark, var_draws):
118
118
 
119
119
  # Construct normalizing flow variational
120
120
  vari = NormalizingFlow(
121
- Standard(model), [Planar(2), Radial(2), Planar(2), Radial(2), Planar(2)], model
121
+ Standard(model), [FullAffine(2), Planar(2), Radial(2)], model
122
122
  )
123
123
 
124
124
  # Optimize variational distribution
@@ -126,7 +126,7 @@ def test_flows(benchmark, var_draws):
126
126
  vari.fit(10000, var_draws=var_draws)
127
127
 
128
128
  benchmark(benchmark_fit)
129
- vari = vari.fit(10000)
129
+ vari = vari.fit(20000)
130
130
 
131
131
  mean = vari.sample(1000).mean(0)
132
132
  var = vari.sample(1000).var(0)