bayinx 0.2.32__tar.gz → 0.3.0__tar.gz
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- {bayinx-0.2.32 → bayinx-0.3.0}/PKG-INFO +1 -1
- {bayinx-0.2.32 → bayinx-0.3.0}/pyproject.toml +2 -2
- bayinx-0.3.0/src/bayinx/__init__.py +2 -0
- bayinx-0.3.0/src/bayinx/constraints/lower.py +51 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/__init__.py +1 -0
- bayinx-0.3.0/src/bayinx/core/constraint.py +28 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/flow.py +6 -4
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/model.py +13 -10
- bayinx-0.3.0/src/bayinx/core/parameter.py +41 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/variational.py +2 -2
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/fullaffine.py +2 -2
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/planar.py +2 -2
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/radial.py +2 -2
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/meanfield.py +19 -17
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/normalizing_flow.py +6 -4
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/standard.py +3 -3
- {bayinx-0.2.32 → bayinx-0.3.0}/tests/test_variational.py +12 -22
- bayinx-0.2.32/src/bayinx/__init__.py +0 -1
- bayinx-0.2.32/src/bayinx/constraints/lower.py +0 -37
- bayinx-0.2.32/src/bayinx/core/constraint.py +0 -26
- {bayinx-0.2.32 → bayinx-0.3.0}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/.gitignore +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/README.md +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/constraints/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/censored/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/censored/gamma2/r.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/dists/uniform.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/py.typed +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/tests/__init__.py +0 -0
- {bayinx-0.2.32 → bayinx-0.3.0}/uv.lock +0 -0

{bayinx-0.2.32 → bayinx-0.3.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.2.32"
+version = "0.3.0"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"

@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"

 [tool.bumpversion]
-current_version = "0.2.32"
+current_version = "0.3.0"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"

bayinx-0.3.0/src/bayinx/constraints/lower.py (new file)

@@ -0,0 +1,51 @@
+from typing import Tuple
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.tree as jt
+from jaxtyping import PyTree, Scalar, ScalarLike
+
+from bayinx.core.constraint import Constraint
+from bayinx.core.parameter import Parameter
+
+
+class Lower(Constraint):
+    """
+    Enforces a lower bound on the parameter.
+    """
+
+    lb: Scalar
+
+    def __init__(self, lb: ScalarLike):
+        self.lb = jnp.array(lb)
+
+    @eqx.filter_jit
+    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+        """
+        Enforces a lower bound on the parameter and adjusts the posterior density.
+
+        # Parameters
+        - `x`: The unconstrained `Parameter`.
+
+        # Parameters
+        A tuple containing:
+        - A modified `Parameter` with relevant leaves satisfying the constraint.
+        - A scalar Array representing the log-absolute-Jacobian of the transformation.
+        """
+        # Extract relevant filter specification
+        filter_spec = x.filter_spec
+
+        # Extract relevant parameters(all Array)
+        dyn_params, static_params = eqx.partition(x, filter_spec)
+
+        # Compute density adjustment
+        laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
+        laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
+
+        # Compute transformation
+        dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
+
+        # Combine into full parameter object
+        x = eqx.combine(dyn_params, static_params)
+
+        return x, laj

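For orientation, the sketch below shows how the new constraint API fits together, inferred from the `Lower` hunk above, the `Model` hunks further down, and the updated tests. The model, the bound value, the prior choices, and the assumption that `Model.constraints` is keyed by parameter name are all illustrative, not taken from the package.

```python
# Hypothetical usage sketch (not shipped with bayinx): a model with a
# lower-bounded scale parameter under the 0.3.0 API.
import equinox as eqx
import jax.numpy as jnp

from bayinx import Model, Parameter
from bayinx.constraints.lower import Lower
from bayinx.dists import normal


class ScaleModel(Model):
    def __init__(self):
        # Both parameters are stored unconstrained; `sigma` is mapped to
        # (0, inf) by the Lower(0.0) constraint when the model is evaluated.
        self.params = {
            "mu": Parameter(jnp.array(0.0)),
            "sigma": Parameter(jnp.array(0.0)),
        }
        self.constraints = {"sigma": Lower(0.0)}  # assumed keying by parameter name

    @eqx.filter_jit
    def eval(self, data=None):
        # `constrain_params` applies the exp-shift transform to `sigma` and
        # folds the log-absolute-Jacobian adjustment into `target`.
        params, target = self.constrain_params()
        mu, sigma = params["mu"].vals, params["sigma"].vals

        # Prior on mu plus a normal likelihood over the observed data.
        target += normal.logprob(x=mu, mu=jnp.array(0.0), sigma=jnp.array(10.0)).sum()
        if data is not None:
            target += normal.logprob(x=data, mu=mu, sigma=sigma).sum()
        return target
```

Whether constraints are matched to parameters by dictionary key is inferred from the `for par, map in self.constraints.items():` loop in the `model.py` hunk below and should be checked against the released source.
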
bayinx-0.3.0/src/bayinx/core/constraint.py (new file)

@@ -0,0 +1,28 @@
+from abc import abstractmethod
+from typing import Tuple
+
+import equinox as eqx
+from jaxtyping import Scalar
+
+from bayinx.core.parameter import Parameter
+
+
+class Constraint(eqx.Module):
+    """
+    Abstract base class for defining parameter constraints.
+    """
+
+    @abstractmethod
+    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+        """
+        Applies the constraining transformation to a parameter and computes the log-absolute-Jacobian of the transformation.
+
+        # Parameters
+        - `x`: The unconstrained `Parameter`.
+
+        # Returns
+        A tuple containing:
+        - The constrained `Parameter`.
+        - A scalar Array representing the log-absolute-Jacobian of the transformation.
+        """
+        pass

{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/flow.py

@@ -31,11 +31,13 @@ class Flow(eqx.Module):
         Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.

         # Returns
-
+        A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
         """
         pass

     # Default filter specification
+    @property
+    @eqx.filter_jit
     def filter_spec(self):
         """
         Generates a filter specification to subset relevant parameters for the flow.

@@ -53,7 +55,7 @@ class Flow(eqx.Module):
         return filter_spec

     @eqx.filter_jit
-    def
+    def constrain_params(self: Self):
         """
         Constrain `params` to the appropriate domain.


@@ -68,11 +70,11 @@ class Flow(eqx.Module):
         return t_params

     @eqx.filter_jit
-    def
+    def transform_params(self: Self) -> Dict[str, Array]:
         """
         Apply a custom transformation to `params` if needed.

         # Returns
         A dictionary of transformed JAX Arrays representing the transformed parameters.
         """
-        return self.
+        return self.constrain_params()

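The practical effect of `filter_spec` becoming a property is that call sites now partition against the attribute itself instead of calling a method, which is what the `Variational`, `MeanField`, `NormalizingFlow`, and `Standard` changes below reflect. A toy, self-contained illustration of the pattern follows; the `Toy` module and its fields are made up and nothing here is taken from the bayinx source.

```python
# Illustrative equinox module with a `filter_spec` property, mirroring the
# 0.3.0 pattern; not bayinx code.
import equinox as eqx
import jax.numpy as jnp
import jax.tree as jt


class Toy(eqx.Module):
    weights: jnp.ndarray
    name: str

    @property
    def filter_spec(self):
        # Mark only the array leaf as dynamic; everything else stays static.
        spec = jt.map(lambda _: False, self)
        return eqx.tree_at(lambda m: m.weights, spec, replace=True)


toy = Toy(weights=jnp.zeros(3), name="toy")

# 0.2.x-style call sites would have used `toy.filter_spec()`; as a property
# it is accessed without parentheses before partitioning.
dyn, static = eqx.partition(toy, toy.filter_spec)
```
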
{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/model.py

@@ -4,9 +4,10 @@ from typing import Any, Dict, Tuple
 import equinox as eqx
 import jax.numpy as jnp
 import jax.tree as jt
-from jaxtyping import Array, Scalar
+from jaxtyping import Array, PyTree, Scalar

 from bayinx.core.constraint import Constraint
+from bayinx.core.parameter import Parameter


 class Model(eqx.Module):

@@ -14,11 +15,11 @@ class Model(eqx.Module):
     An abstract base class used to define probabilistic models.

     # Attributes
-    - `params`: A dictionary of
+    - `params`: A dictionary of Arrays or PyTrees containing Arrays representing parameters of the model.
     - `constraints`: A dictionary of constraints.
     """

-    params: Dict[str,
+    params: Dict[str, Parameter]
     constraints: Dict[str, Constraint]

     @abstractmethod

@@ -26,6 +27,8 @@ class Model(eqx.Module):
         pass

     # Default filter specification
+    @property
+    @eqx.filter_jit
     def filter_spec(self):
         """
         Generates a filter specification to subset relevant parameters for the model.

@@ -33,25 +36,25 @@ class Model(eqx.Module):
         # Generate empty specification
         filter_spec = jt.map(lambda _: False, self)

-        # Specify
+        # Specify relevant parameters
         filter_spec = eqx.tree_at(
             lambda model: model.params,
             filter_spec,
-            replace=
+            replace={key: param.filter_spec for key, param in self.params.items()}
         )

         return filter_spec

     # Add constrain method
     @eqx.filter_jit
-    def
+    def constrain_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
         """
         Constrain `params` to the appropriate domain.

         # Returns
-        A dictionary of
+        A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
         """
-        t_params: Dict[str, Array] = self.params
+        t_params: Dict[str, Array | PyTree] = self.params
         target: Scalar = jnp.array(0.0)

         for par, map in self.constraints.items():

@@ -64,11 +67,11 @@ class Model(eqx.Module):
         return t_params, target

     # Add default transform method
-    def
+    def transform_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
         """
         Apply a custom transformation to `params` if needed.

         # Returns
         A dictionary of transformed JAX Arrays representing the transformed parameters.
         """
-        return self.
+        return self.constrain_params()

bayinx-0.3.0/src/bayinx/core/parameter.py (new file)

@@ -0,0 +1,41 @@
+from typing import Self
+
+import equinox as eqx
+import jax.tree as jt
+from jaxtyping import Array, PyTree
+
+
+class Parameter(eqx.Module):
+    """
+    A container for a parameter of a `Model`.
+
+    Subclasses can be constructed for custom filter specifications(`filter_spec`).
+
+    # Attributes
+    - `vals`: The parameter's value(s).
+    """
+    vals: Array | PyTree
+
+
+    def __init__(self, values: Array | PyTree):
+        # Insert parameter values
+        self.vals = values
+
+    # Default filter specification
+    @property
+    @eqx.filter_jit
+    def filter_spec(self) -> Self:
+        """
+        Generates a filter specification to filter out static parameters.
+        """
+        # Generate empty specification
+        filter_spec = jt.map(lambda _: False, self)
+
+        # Specify Array leaves
+        filter_spec = eqx.tree_at(
+            lambda params: params.vals,
+            filter_spec,
+            replace=jt.map(eqx.is_array_like, self.vals),
+        )
+
+        return filter_spec

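To make the role of the new `Parameter` container concrete, here is a small hypothetical check of its default `filter_spec`: array leaves under `vals` are marked `True`, so `eqx.partition` can split a parameter into its array content and the static remainder. The parameter value used here is made up.

```python
# Hypothetical demonstration of Parameter.filter_spec, based only on the
# code shown in this diff.
import equinox as eqx
import jax.numpy as jnp

from bayinx import Parameter

# A parameter whose value is a small PyTree of arrays.
p = Parameter({"loc": jnp.zeros(2), "scale": jnp.ones(2)})

# `filter_spec` is a property in 0.3.0, so it is read without a call.
spec = p.filter_spec

# Array leaves under `vals` are marked True; partitioning therefore keeps
# the arrays in `dyn` and leaves the static structure in `static`.
dyn, static = eqx.partition(p, spec)
```
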
{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/core/variational.py

@@ -103,7 +103,7 @@ class Variational(eqx.Module):
         - `key`: A PRNG key.
         """
         # Partition variational
-        dyn, static = eqx.partition(self, self.filter_spec())
+        dyn, static = eqx.partition(self, self.filter_spec)

         # Construct scheduler
         schedule: Schedule = opx.cosine_decay_schedule(

@@ -143,7 +143,7 @@ class Variational(eqx.Module):

             # Compute updates
             updates, opt_state = optim.update(
-                updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+                updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
             )

             # Update variational distribution

{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/fullaffine.py

@@ -46,7 +46,7 @@ class FullAffine(Flow):

     @eqx.filter_jit
     def forward(self, draws: Array) -> Array:
-        params = self.
+        params = self.transform_params()

         # Extract parameters
         shift: Array = params["shift"]

@@ -60,7 +60,7 @@ class FullAffine(Flow):
     @eqx.filter_jit
     @partial(jax.vmap, in_axes=(None, 0))
     def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
-        params = self.
+        params = self.transform_params()

         # Extract parameters
         shift: Array = params["shift"]

{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/planar.py

@@ -39,7 +39,7 @@ class Planar(Flow):
     @eqx.filter_jit
     @partial(jax.vmap, in_axes=(None, 0))
     def forward(self, draws: Array) -> Array:
-        params = self.
+        params = self.transform_params()

         # Extract parameters
         w: Array = params["w"]

@@ -54,7 +54,7 @@ class Planar(Flow):
     @eqx.filter_jit
     @partial(jax.vmap, in_axes=(None, 0))
     def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
-        params = self.
+        params = self.transform_params()

         # Extract parameters
         w: Array = params["w"]

{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/flows/radial.py

@@ -49,7 +49,7 @@ class Radial(Flow):
         # Returns
         The transformed samples.
         """
-        params = self.
+        params = self.transform_params()

         # Extract parameters
         alpha = params["alpha"]

@@ -67,7 +67,7 @@ class Radial(Flow):
     @partial(jax.vmap, in_axes=(None, 0))
     @eqx.filter_jit
     def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
-        params = self.
+        params = self.transform_params()

         # Extract parameters
         alpha = params["alpha"]

{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/meanfield.py

@@ -29,7 +29,7 @@ class MeanField(Variational):
         - `model`: A probabilistic `Model` object.
         """
         # Partition model
-        params, self._constraints = eqx.partition(model, model.filter_spec())
+        params, self._constraints = eqx.partition(model, model.filter_spec)

         # Flatten params component
         params, self._unflatten = ravel_pytree(params)

@@ -40,6 +40,22 @@ class MeanField(Variational):
             "log_std": jnp.zeros(params.size, dtype=params.dtype),
         }

+    @property
+    @eqx.filter_jit
+    def filter_spec(self):
+        # Generate empty specification
+        filter_spec = jtu.tree_map(lambda _: False, self)
+
+        # Specify variational parameters
+        filter_spec = eqx.tree_at(
+            lambda mf: mf.var_params,
+            filter_spec,
+            replace=True,
+        )
+
+        return filter_spec
+
+
     @eqx.filter_jit
     def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
         # Sample variational draws

@@ -59,23 +75,9 @@ class MeanField(Variational):
             sigma=jnp.exp(self.var_params["log_std"]),
         ).sum(axis=1)

-    @eqx.filter_jit
-    def filter_spec(self):
-        # Generate empty specification
-        filter_spec = jtu.tree_map(lambda _: False, self)
-
-        # Specify variational parameters
-        filter_spec = eqx.tree_at(
-            lambda mf: mf.var_params,
-            filter_spec,
-            replace=True,
-        )
-
-        return filter_spec
-
     @eqx.filter_jit
     def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
-        dyn, static = eqx.partition(self, self.filter_spec())
+        dyn, static = eqx.partition(self, self.filter_spec)

         @eqx.filter_jit
         def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:

@@ -97,7 +99,7 @@ class MeanField(Variational):

     @eqx.filter_jit
     def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-        dyn, static = eqx.partition(self, self.filter_spec())
+        dyn, static = eqx.partition(self, self.filter_spec)

         @eqx.filter_grad
         @eqx.filter_jit

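Putting the pieces together, here is a hedged end-to-end sketch of fitting the mean-field family under the new API. The constructor call `MeanField(model)` is inferred from the `__init__` shown above, the iteration and draw counts are arbitrary, and the `fit` and `sample` calls mirror usages that appear elsewhere in this diff.

```python
# Hypothetical end-to-end sketch; the API is inferred from this diff and
# has not been verified against the released package.
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

from bayinx import Model, Parameter
from bayinx.dists import normal
from bayinx.mhx.vi import MeanField


class NormalDist(Model):
    def __init__(self):
        self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
        self.constraints = {}

    @eqx.filter_jit
    def eval(self, data=None):
        params, target = self.constrain_params()
        # mu ~ N(10, 1)
        target += normal.logprob(
            x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
        ).sum()
        return target


vari = MeanField(NormalDist())        # assumed constructor: MeanField(model)
vari = vari.fit(10_000, var_draws=4)  # `fit(...)` usage mirrors the test file
draws = vari.sample(100, key=jr.PRNGKey(0))
```
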
{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/normalizing_flow.py

@@ -33,7 +33,7 @@ class NormalizingFlow(Variational):
         - `model`: A probabilistic `Model` object.
         """
         # Partition model
-        params, self._constraints = eqx.partition(model,
+        params, self._constraints = eqx.partition(model, model.filter_spec)

         # Flatten params component
         _, self._unflatten = jfu.ravel_pytree(params)

@@ -41,6 +41,8 @@ class NormalizingFlow(Variational):
         self.base = base
         self.flows = flows

+    @property
+    @eqx.filter_jit
     def filter_spec(self):
         # Generate empty specification
         filter_spec = jtu.tree_map(lambda _: False, self)

@@ -49,7 +51,7 @@ class NormalizingFlow(Variational):
         filter_spec = eqx.tree_at(
             lambda vari: vari.flows,
             filter_spec,
-            replace=[flow.filter_spec
+            replace=[flow.filter_spec for flow in self.flows],
         )

         return filter_spec

@@ -112,7 +114,7 @@ class NormalizingFlow(Variational):

     @eqx.filter_jit
     def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
-        dyn, static = eqx.partition(self, self.filter_spec())
+        dyn, static = eqx.partition(self, self.filter_spec)

         @eqx.filter_jit
         def elbo(dyn: Self, n: int, key: Key, data: Any = None):

@@ -129,7 +131,7 @@ class NormalizingFlow(Variational):

     @eqx.filter_jit
     def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-        dyn, static = eqx.partition(self, self.filter_spec())
+        dyn, static = eqx.partition(self, self.filter_spec)

         @eqx.filter_grad
         @eqx.filter_jit

{bayinx-0.2.32 → bayinx-0.3.0}/src/bayinx/mhx/vi/standard.py

@@ -19,7 +19,7 @@ class Standard(Variational):
     - `dim`: Dimension of the parameter space.
     """

-    dim: int
+    dim: int
     _unflatten: Callable[[Float[Array, "..."]], Model]
     _constraints: Model

@@ -31,7 +31,7 @@ class Standard(Variational):
         - `model`: A probabilistic `Model` object.
         """
         # Partition model
-        params, self._constraints = eqx.partition(model, model.filter_spec())
+        params, self._constraints = eqx.partition(model, model.filter_spec)

         # Flatten params component
         params, self._unflatten = ravel_pytree(params)

@@ -54,7 +54,7 @@ class Standard(Variational):
             sigma=jnp.array(1.0),
         ).sum(axis=1, keepdims=True)

-    @
+    @property
     def filter_spec(self):
         filter_spec = jtu.tree_map(lambda _: False, self)

{bayinx-0.2.32 → bayinx-0.3.0}/tests/test_variational.py

@@ -1,11 +1,9 @@
-from typing import Callable, Dict

 import equinox as eqx
 import jax.numpy as jnp
 import pytest
-from jaxtyping import Array

-from bayinx import Model
+from bayinx import Model, Parameter
 from bayinx.dists import normal
 from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
 from bayinx.mhx.vi.flows import FullAffine, Planar, Radial

@@ -16,21 +14,18 @@ from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
 def test_meanfield(benchmark, var_draws):
     # Construct model definition
     class NormalDist(Model):
-        params: Dict[str, Array]
-        constraints: Dict[str, Callable[[Array], Array]]
-
         def __init__(self):
-            self.params = {"mu": jnp.array([0.0, 0.0])}
+            self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
             self.constraints = {}

         @eqx.filter_jit
-        def eval(self, data
+        def eval(self, data = None):
             # Get constrained parameters
-            params, target = self.
+            params, target = self.constrain_params()

             # Evaluate mu ~ N(10,1)
             target += normal.logprob(
-                x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
+                x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
             ).sum()

             # Evaluate mu ~ N(10,1)

@@ -59,21 +54,19 @@ def test_meanfield(benchmark, var_draws):
 def test_affine(benchmark, var_draws):
     # Construct model definition
     class NormalDist(Model):
-        params: Dict[str, Array]
-        constraints: Dict[str, Callable[[Array], Array]]

         def __init__(self):
-            self.params = {"mu": jnp.array([0.0, 0.0])}
+            self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
             self.constraints = {}

         @eqx.filter_jit
         def eval(self, data: dict):
             # Get constrained parameters
-            params, target = self.
+            params, target = self.constrain_params()

             # Evaluate mu ~ N(10,1)
             target += normal.logprob(
-                x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
+                x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
             ).sum()

             # Evaluate mu ~ N(10,1)

@@ -92,7 +85,7 @@ def test_affine(benchmark, var_draws):
     benchmark(benchmark_fit)
     vari = vari.fit(20000, var_draws=var_draws)

-    params = vari.flows[0].
+    params = vari.flows[0].transform_params()
     assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
         abs(jnp.eye(2) - params["scale"]) < 0.1
     ).all()

@@ -102,21 +95,18 @@ def test_affine(benchmark, var_draws):
 def test_flows(benchmark, var_draws):
     # Construct model definition
     class NormalDist(Model):
-        params: Dict[str, Array]
-        constraints: Dict[str, Callable[[Array], Array]]
-
         def __init__(self):
-            self.params = {"mu": jnp.array([0.0, 0.0])}
+            self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
             self.constraints = {}

         @eqx.filter_jit
         def eval(self, data: dict):
             # Get constrained parameters
-            params, target = self.
+            params, target = self.constrain_params()

             # Evaluate mu ~ N(10,1)
             target += normal.logprob(
-                x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
+                x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
             ).sum()

             # Evaluate mu ~ N(10,1)
             return target

bayinx-0.2.32/src/bayinx/__init__.py (removed)

@@ -1 +0,0 @@
-from bayinx.core.model import Model as Model

bayinx-0.2.32/src/bayinx/constraints/lower.py (removed)

@@ -1,37 +0,0 @@
-from typing import Tuple
-
-import jax.numpy as jnp
-from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
-
-from bayinx.core.constraint import Constraint
-
-
-class Lower(Constraint):
-    """
-    Enforces a lower bound on the parameter.
-    """
-
-    lb: ScalarLike
-
-    def __init__(self, lb: ScalarLike):
-        self.lb = lb
-
-    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
-        """
-        Applies the lower bound constraint and adjusts the posterior density.
-
-        # Parameters
-        - `x`: The unconstrained JAX Array-like input.
-
-        # Parameters
-        A tuple containing:
-        - The constrained JAX Array (x > self.lb).
-        - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
-        """
-        # Compute transformation adjustment
-        laj: Scalar = jnp.sum(x)
-
-        # Compute transformation
-        x = jnp.exp(x) + self.lb
-
-        return x, laj

bayinx-0.2.32/src/bayinx/core/constraint.py (removed)

@@ -1,26 +0,0 @@
-from abc import abstractmethod
-from typing import Tuple
-
-import equinox as eqx
-from jaxtyping import Array, ArrayLike, Scalar
-
-
-class Constraint(eqx.Module):
-    """
-    Abstract base class for defining parameter constraints.
-    """
-
-    @abstractmethod
-    def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
-        """
-        Applies the constraining transformation to an unconstrained input and computes the log-absolute-Jacobian of the transformation.
-
-        # Parameters
-        - `x`: The unconstrained JAX Array-like input.
-
-        # Returns
-        A tuple containing:
-        - The constrained JAX Array.
-        - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
-        """
-        pass