bayinx 0.2.10__tar.gz → 0.2.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayinx-0.2.10 → bayinx-0.2.11}/PKG-INFO +1 -1
- {bayinx-0.2.10 → bayinx-0.2.11}/pyproject.toml +1 -1
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/__init__.py +1 -1
- bayinx-0.2.10/src/bayinx/mhx/vi/flows/affine.py → bayinx-0.2.11/src/bayinx/mhx/vi/flows/fullaffine.py +1 -1
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/planar.py +16 -2
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/normalizing_flow.py +1 -1
- {bayinx-0.2.10 → bayinx-0.2.11}/tests/test_variational.py +22 -15
- {bayinx-0.2.10 → bayinx-0.2.11}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/.gitignore +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/README.md +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/flow.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/model.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/utils.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/core/variational.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/binomial.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/gamma.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/radial.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/meanfield.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/mhx/vi/standard.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/src/bayinx/py.typed +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/tests/__init__.py +0 -0
- {bayinx-0.2.10 → bayinx-0.2.11}/uv.lock +0 -0
@@ -3,6 +3,7 @@ from typing import Callable, Dict, Tuple
|
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
5
|
import jax
|
6
|
+
import jax.nn as jnn
|
6
7
|
import jax.numpy as jnp
|
7
8
|
import jax.random as jr
|
8
9
|
from jaxtyping import Array, Float, Scalar
|
@@ -36,10 +37,23 @@ class Planar(Flow):
|
|
36
37
|
}
|
37
38
|
self.constraints = {}
|
38
39
|
|
40
|
+
def transform_pars(self):
|
41
|
+
params = self.params
|
42
|
+
|
43
|
+
u = params['u']
|
44
|
+
w = params['w']
|
45
|
+
b = params['b']
|
46
|
+
|
47
|
+
m = jnn.softplus(w.dot(u)) - 1.0
|
48
|
+
|
49
|
+
u = u + (m - w.dot(u)) * w / (w**2).sum()
|
50
|
+
|
51
|
+
return {'u': u, 'w': w, 'b': b}
|
52
|
+
|
39
53
|
@eqx.filter_jit
|
40
54
|
@partial(jax.vmap, in_axes=(None, 0))
|
41
55
|
def forward(self, draws: Array) -> Array:
|
42
|
-
params = self.
|
56
|
+
params = self.transform_pars()
|
43
57
|
|
44
58
|
# Extract parameters
|
45
59
|
w: Array = params["w"]
|
@@ -54,7 +68,7 @@ class Planar(Flow):
|
|
54
68
|
@eqx.filter_jit
|
55
69
|
@partial(jax.vmap, in_axes=(None, 0))
|
56
70
|
def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
|
57
|
-
params = self.
|
71
|
+
params = self.transform_pars()
|
58
72
|
|
59
73
|
# Extract parameters
|
60
74
|
w: Array = params["w"]
|
@@ -115,7 +115,7 @@ class NormalizingFlow(Variational):
|
|
115
115
|
return filter_spec
|
116
116
|
|
117
117
|
@eqx.filter_jit
|
118
|
-
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
118
|
+
def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
|
119
119
|
dyn, static = eqx.partition(self, self.filter_spec())
|
120
120
|
|
121
121
|
@eqx.filter_jit
|
@@ -8,7 +8,7 @@ from jaxtyping import Array
|
|
8
8
|
from bayinx import Model
|
9
9
|
from bayinx.dists import normal
|
10
10
|
from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
|
11
|
-
from bayinx.mhx.vi.flows import
|
11
|
+
from bayinx.mhx.vi.flows import FullAffine, Planar
|
12
12
|
|
13
13
|
|
14
14
|
# Tests ----
|
@@ -77,7 +77,7 @@ def test_affine(benchmark, var_draws):
|
|
77
77
|
model = NormalDist()
|
78
78
|
|
79
79
|
# Construct normalizing flow variational
|
80
|
-
vari = NormalizingFlow(Standard(model), [
|
80
|
+
vari = NormalizingFlow(Standard(model), [FullAffine(2)], model)
|
81
81
|
|
82
82
|
# Optimize variational distribution
|
83
83
|
def benchmark_fit():
|
@@ -95,30 +95,37 @@ def test_affine(benchmark, var_draws):
|
|
95
95
|
@pytest.mark.parametrize("var_draws", [1, 10, 100])
|
96
96
|
def test_flows(benchmark, var_draws):
|
97
97
|
# Construct model definition
|
98
|
-
class
|
98
|
+
class Banana(Model):
|
99
99
|
params: Dict[str, Array]
|
100
100
|
constraints: Dict[str, Callable[[Array], Array]]
|
101
101
|
|
102
102
|
def __init__(self):
|
103
|
-
self.params = {
|
103
|
+
self.params = {
|
104
|
+
'x': jnp.array(0.0),
|
105
|
+
'y': jnp.array(0.0)
|
106
|
+
}
|
104
107
|
self.constraints = {}
|
105
108
|
|
106
|
-
|
107
|
-
|
108
|
-
#
|
109
|
-
|
109
|
+
def eval(self, data = None):
|
110
|
+
params: Dict[str, Array] = self.params
|
111
|
+
# Extract parameters
|
112
|
+
x: Array = params['x']
|
113
|
+
y: Array = params['y']
|
110
114
|
|
111
|
-
#
|
112
|
-
|
113
|
-
|
114
|
-
)
|
115
|
+
# Initialize target density
|
116
|
+
target = jnp.array(0.0)
|
117
|
+
|
118
|
+
target += normal.logprob(x, mu = jnp.array(0.0), sigma = jnp.array(1.0))
|
119
|
+
target += normal.logprob(y, mu = x**2 + x, sigma = jnp.array(1.0))
|
120
|
+
|
121
|
+
return target
|
115
122
|
|
116
123
|
# Construct model
|
117
|
-
model =
|
124
|
+
model = Banana()
|
118
125
|
|
119
126
|
# Construct normalizing flow variational
|
120
127
|
vari = NormalizingFlow(
|
121
|
-
Standard(model), [
|
128
|
+
Standard(model), [FullAffine(2), Planar(2)], model
|
122
129
|
)
|
123
130
|
|
124
131
|
# Optimize variational distribution
|
@@ -126,7 +133,7 @@ def test_flows(benchmark, var_draws):
|
|
126
133
|
vari.fit(10000, var_draws=var_draws)
|
127
134
|
|
128
135
|
benchmark(benchmark_fit)
|
129
|
-
vari = vari.fit(
|
136
|
+
vari = vari.fit(100)
|
130
137
|
|
131
138
|
mean = vari.sample(1000).mean(0)
|
132
139
|
var = vari.sample(1000).var(0)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|