bayinx 0.2.10__tar.gz → 0.2.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayinx-0.2.10 → bayinx-0.2.12}/PKG-INFO +1 -1
- {bayinx-0.2.10 → bayinx-0.2.12}/pyproject.toml +2 -2
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/__init__.py +1 -1
- bayinx-0.2.10/src/bayinx/mhx/vi/flows/affine.py → bayinx-0.2.12/src/bayinx/mhx/vi/flows/fullaffine.py +1 -1
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/planar.py +5 -5
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/normalizing_flow.py +1 -1
- {bayinx-0.2.10 → bayinx-0.2.12}/tests/test_variational.py +6 -6
- bayinx-0.2.12/uv.lock +596 -0
- bayinx-0.2.10/uv.lock +0 -360
- {bayinx-0.2.10 → bayinx-0.2.12}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/.gitignore +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/README.md +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/flow.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/model.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/utils.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/core/variational.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/binomial.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/gamma.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/radial.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/meanfield.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/mhx/vi/standard.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/src/bayinx/py.typed +0 -0
- {bayinx-0.2.10 → bayinx-0.2.12}/tests/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "bayinx"
|
3
|
-
version = "0.2.
|
3
|
+
version = "0.2.12"
|
4
4
|
description = "Bayesian Inference with JAX"
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.12"
|
@@ -19,4 +19,4 @@ build-backend = "hatchling.build"
|
|
19
19
|
addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
|
20
20
|
|
21
21
|
[dependency-groups]
|
22
|
-
dev = ["pytest>=8.3.5", "pytest-benchmark>=5.1.0"]
|
22
|
+
dev = ["matplotlib>=3.10.1", "pytest>=8.3.5", "pytest-benchmark>=5.1.0"]
|
@@ -30,16 +30,16 @@ class Planar(Flow):
|
|
30
30
|
- `dim`: The dimension of the parameter space.
|
31
31
|
"""
|
32
32
|
self.params = {
|
33
|
-
"u":
|
34
|
-
"w":
|
35
|
-
"b":
|
33
|
+
"u": jnp.ones(dim),
|
34
|
+
"w": jnp.ones(dim),
|
35
|
+
"b": jnp.zeros(1),
|
36
36
|
}
|
37
37
|
self.constraints = {}
|
38
38
|
|
39
39
|
@eqx.filter_jit
|
40
40
|
@partial(jax.vmap, in_axes=(None, 0))
|
41
41
|
def forward(self, draws: Array) -> Array:
|
42
|
-
params = self.
|
42
|
+
params = self.transform_pars()
|
43
43
|
|
44
44
|
# Extract parameters
|
45
45
|
w: Array = params["w"]
|
@@ -54,7 +54,7 @@ class Planar(Flow):
|
|
54
54
|
@eqx.filter_jit
|
55
55
|
@partial(jax.vmap, in_axes=(None, 0))
|
56
56
|
def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
|
57
|
-
params = self.
|
57
|
+
params = self.transform_pars()
|
58
58
|
|
59
59
|
# Extract parameters
|
60
60
|
w: Array = params["w"]
|
@@ -115,7 +115,7 @@ class NormalizingFlow(Variational):
|
|
115
115
|
return filter_spec
|
116
116
|
|
117
117
|
@eqx.filter_jit
|
118
|
-
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
118
|
+
def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
|
119
119
|
dyn, static = eqx.partition(self, self.filter_spec())
|
120
120
|
|
121
121
|
@eqx.filter_jit
|
@@ -8,7 +8,7 @@ from jaxtyping import Array
|
|
8
8
|
from bayinx import Model
|
9
9
|
from bayinx.dists import normal
|
10
10
|
from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
|
11
|
-
from bayinx.mhx.vi.flows import
|
11
|
+
from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
|
12
12
|
|
13
13
|
|
14
14
|
# Tests ----
|
@@ -44,7 +44,7 @@ def test_meanfield(benchmark, var_draws):
|
|
44
44
|
vari.fit(10000, var_draws=var_draws)
|
45
45
|
|
46
46
|
benchmark(benchmark_fit)
|
47
|
-
vari = vari.fit(
|
47
|
+
vari = vari.fit(20000)
|
48
48
|
|
49
49
|
# Assert parameters are roughly correct
|
50
50
|
assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
|
@@ -77,14 +77,14 @@ def test_affine(benchmark, var_draws):
|
|
77
77
|
model = NormalDist()
|
78
78
|
|
79
79
|
# Construct normalizing flow variational
|
80
|
-
vari = NormalizingFlow(Standard(model), [
|
80
|
+
vari = NormalizingFlow(Standard(model), [FullAffine(2)], model)
|
81
81
|
|
82
82
|
# Optimize variational distribution
|
83
83
|
def benchmark_fit():
|
84
84
|
vari.fit(10000, var_draws=var_draws)
|
85
85
|
|
86
86
|
benchmark(benchmark_fit)
|
87
|
-
vari = vari.fit(
|
87
|
+
vari = vari.fit(20000)
|
88
88
|
|
89
89
|
params = vari.flows[0].constrain_pars()
|
90
90
|
assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
|
@@ -118,7 +118,7 @@ def test_flows(benchmark, var_draws):
|
|
118
118
|
|
119
119
|
# Construct normalizing flow variational
|
120
120
|
vari = NormalizingFlow(
|
121
|
-
Standard(model), [
|
121
|
+
Standard(model), [FullAffine(2), Planar(2), Radial(2)], model
|
122
122
|
)
|
123
123
|
|
124
124
|
# Optimize variational distribution
|
@@ -126,7 +126,7 @@ def test_flows(benchmark, var_draws):
|
|
126
126
|
vari.fit(10000, var_draws=var_draws)
|
127
127
|
|
128
128
|
benchmark(benchmark_fit)
|
129
|
-
vari = vari.fit(
|
129
|
+
vari = vari.fit(20000)
|
130
130
|
|
131
131
|
mean = vari.sample(1000).mean(0)
|
132
132
|
var = vari.sample(1000).var(0)
|