bayinx 0.3.1__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayinx-0.3.1 → bayinx-0.3.3}/PKG-INFO +1 -1
- {bayinx-0.3.1 → bayinx-0.3.3}/pyproject.toml +2 -2
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/__init__.py +1 -0
- bayinx-0.3.3/src/bayinx/constraints/__init__.py +1 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/core/flow.py +1 -1
- bayinx-0.3.3/src/bayinx/core/model.py +102 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/core/parameter.py +3 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/tests/test_variational.py +20 -64
- bayinx-0.3.1/src/bayinx/core/model.py +0 -78
- bayinx-0.3.1/tests/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/.gitignore +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/README.md +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/constraints/lower.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/core/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/core/constraint.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/core/variational.py +0 -0
- {bayinx-0.3.1/src/bayinx/constraints → bayinx-0.3.3/src/bayinx/dists}/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.3.1/src/bayinx/dists → bayinx-0.3.3/src/bayinx/dists/censored}/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/dists/censored/gamma2/r.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/dists/uniform.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/flows/planar.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/flows/radial.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/meanfield.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/mhx/vi/standard.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/src/bayinx/py.typed +0 -0
- {bayinx-0.3.1/src/bayinx/dists/censored → bayinx-0.3.3/tests}/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.3}/uv.lock +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "bayinx"
|
3
|
-
version = "0.3.
|
3
|
+
version = "0.3.3"
|
4
4
|
description = "Bayesian Inference with JAX"
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.12"
|
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
|
|
19
19
|
addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
|
20
20
|
|
21
21
|
[tool.bumpversion]
|
22
|
-
current_version = "0.3.
|
22
|
+
current_version = "0.3.3"
|
23
23
|
parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
|
24
24
|
serialize = ["{major}.{minor}.{patch}"]
|
25
25
|
search = "{current_version}"
|
@@ -0,0 +1 @@
|
|
1
|
+
from bayinx.constraints.lower import Lower as Lower
|
@@ -11,7 +11,7 @@ class Flow(eqx.Module):
|
|
11
11
|
An abstract base class for a flow(of a normalizing flow).
|
12
12
|
|
13
13
|
# Attributes
|
14
|
-
- `
|
14
|
+
- `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
|
15
15
|
- `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
|
16
16
|
"""
|
17
17
|
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from dataclasses import field, fields
|
3
|
+
from typing import Any, Self, Tuple
|
4
|
+
|
5
|
+
import equinox as eqx
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax.tree as jt
|
8
|
+
from jaxtyping import Scalar
|
9
|
+
|
10
|
+
from bayinx.core.constraint import Constraint
|
11
|
+
from bayinx.core.parameter import Parameter
|
12
|
+
|
13
|
+
|
14
|
+
def constrain(constraint: Constraint):
|
15
|
+
"""Defines constraint metadata."""
|
16
|
+
return field(metadata={'constraint': constraint})
|
17
|
+
|
18
|
+
|
19
|
+
class Model(eqx.Module):
|
20
|
+
"""
|
21
|
+
An abstract base class used to define probabilistic models.
|
22
|
+
|
23
|
+
Annotate parameter attributes with `Parameter`.
|
24
|
+
|
25
|
+
Include constraints by setting them equal to `constrain(Constraint)`.
|
26
|
+
"""
|
27
|
+
|
28
|
+
@abstractmethod
|
29
|
+
def eval(self, data: Any) -> Scalar:
|
30
|
+
pass
|
31
|
+
|
32
|
+
# Default filter specification
|
33
|
+
@property
|
34
|
+
@eqx.filter_jit
|
35
|
+
def filter_spec(self) -> Self:
|
36
|
+
"""
|
37
|
+
Generates a filter specification to subset relevant parameters for the model.
|
38
|
+
"""
|
39
|
+
# Generate empty specification
|
40
|
+
filter_spec: Self = jt.map(lambda _: False, self)
|
41
|
+
|
42
|
+
for f in fields(self):
|
43
|
+
# Extract attribute from field
|
44
|
+
attr = getattr(self, f.name)
|
45
|
+
|
46
|
+
# Check if attribute is a parameter
|
47
|
+
if isinstance(attr, Parameter):
|
48
|
+
# Update filter specification for parameter
|
49
|
+
filter_spec = eqx.tree_at(
|
50
|
+
lambda model: getattr(model, f.name),
|
51
|
+
filter_spec,
|
52
|
+
replace=attr.filter_spec
|
53
|
+
)
|
54
|
+
|
55
|
+
return filter_spec
|
56
|
+
|
57
|
+
|
58
|
+
@eqx.filter_jit
|
59
|
+
def constrain_params(self) -> Tuple[Self, Scalar]:
|
60
|
+
"""
|
61
|
+
Constrain parameters to the appropriate domain.
|
62
|
+
|
63
|
+
# Returns
|
64
|
+
A constrained `Model` object and the adjustment to the posterior.
|
65
|
+
"""
|
66
|
+
constrained: Self = self
|
67
|
+
target: Scalar = jnp.array(0.0)
|
68
|
+
|
69
|
+
for f in fields(self):
|
70
|
+
# Extract attribute
|
71
|
+
attr = getattr(self, f.name)
|
72
|
+
|
73
|
+
# Check if constrained parameter
|
74
|
+
if isinstance(attr, Parameter) and 'constraint' in f.metadata:
|
75
|
+
param = attr
|
76
|
+
constraint = f.metadata['constraint']
|
77
|
+
|
78
|
+
# Apply constraint
|
79
|
+
param, laj = constraint.constrain(param)
|
80
|
+
|
81
|
+
# Update parameters for constrained model
|
82
|
+
constrained = eqx.tree_at(
|
83
|
+
lambda model: getattr(model, f.name),
|
84
|
+
constrained,
|
85
|
+
replace=param
|
86
|
+
)
|
87
|
+
|
88
|
+
# Adjust posterior density
|
89
|
+
target += laj
|
90
|
+
|
91
|
+
return constrained, target
|
92
|
+
|
93
|
+
|
94
|
+
@eqx.filter_jit
|
95
|
+
def transform_params(self) -> Tuple[Self, Scalar]:
|
96
|
+
"""
|
97
|
+
Apply a custom transformation to parameters if needed(defaults to constrained parameters).
|
98
|
+
|
99
|
+
# Returns
|
100
|
+
A transformed `Model` object and the adjustment to the posterior.
|
101
|
+
"""
|
102
|
+
return self.constrain_params()
|
@@ -1,7 +1,5 @@
|
|
1
|
-
|
2
1
|
from typing import Dict
|
3
2
|
|
4
|
-
import equinox as eqx
|
5
3
|
import jax.numpy as jnp
|
6
4
|
import pytest
|
7
5
|
from jaxtyping import Array
|
@@ -12,30 +10,25 @@ from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
|
|
12
10
|
from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
|
13
11
|
|
14
12
|
|
15
|
-
|
16
|
-
|
17
|
-
def test_meanfield(benchmark, var_draws):
|
18
|
-
# Construct model definition
|
19
|
-
class NormalDist(Model[Array]):
|
20
|
-
params: Dict[str, Parameter[Array]]
|
13
|
+
class NormalDist(Model):
|
14
|
+
x: Parameter[Array]
|
21
15
|
|
22
|
-
|
23
|
-
|
24
|
-
self.constraints = {}
|
16
|
+
def __init__(self):
|
17
|
+
self.x = Parameter(jnp.array([0.0, 0.0]))
|
25
18
|
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
params, target = self.constrain_params()
|
19
|
+
def eval(self, data: Dict[str, Array]):
|
20
|
+
# Constrain parameters
|
21
|
+
self, target = self.constrain_params()
|
30
22
|
|
31
|
-
|
32
|
-
|
33
|
-
x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
|
34
|
-
).sum()
|
23
|
+
# Evaluate x ~ Normal(10.0, 1.0)
|
24
|
+
target += jnp.sum(normal.logprob(self.x(), jnp.array(10.0), jnp.array(1.0)))
|
35
25
|
|
36
|
-
|
37
|
-
return target
|
26
|
+
return target
|
38
27
|
|
28
|
+
|
29
|
+
# Tests ----
|
30
|
+
@pytest.mark.parametrize("var_draws", [1, 100])
|
31
|
+
def test_meanfield(benchmark, var_draws):
|
39
32
|
# Construct model
|
40
33
|
model = NormalDist()
|
41
34
|
|
@@ -46,8 +39,8 @@ def test_meanfield(benchmark, var_draws):
|
|
46
39
|
def benchmark_fit():
|
47
40
|
vari.fit(10000, var_draws=var_draws)
|
48
41
|
|
49
|
-
benchmark(benchmark_fit)
|
50
42
|
vari = vari.fit(20000, var_draws=var_draws)
|
43
|
+
benchmark(benchmark_fit)
|
51
44
|
|
52
45
|
# Assert parameters are roughly correct
|
53
46
|
assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
|
@@ -55,28 +48,8 @@ def test_meanfield(benchmark, var_draws):
|
|
55
48
|
)
|
56
49
|
|
57
50
|
|
58
|
-
@pytest.mark.parametrize("var_draws", [1,
|
51
|
+
@pytest.mark.parametrize("var_draws", [1, 100])
|
59
52
|
def test_affine(benchmark, var_draws):
|
60
|
-
# Construct model definition
|
61
|
-
class NormalDist(Model):
|
62
|
-
|
63
|
-
def __init__(self):
|
64
|
-
self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
|
65
|
-
self.constraints = {}
|
66
|
-
|
67
|
-
@eqx.filter_jit
|
68
|
-
def eval(self, data: dict):
|
69
|
-
# Get constrained parameters
|
70
|
-
params, target = self.constrain_params()
|
71
|
-
|
72
|
-
# Evaluate mu ~ N(10,1)
|
73
|
-
target += normal.logprob(
|
74
|
-
x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
|
75
|
-
).sum()
|
76
|
-
|
77
|
-
# Evaluate mu ~ N(10,1)
|
78
|
-
return target
|
79
|
-
|
80
53
|
# Construct model
|
81
54
|
model = NormalDist()
|
82
55
|
|
@@ -87,8 +60,8 @@ def test_affine(benchmark, var_draws):
|
|
87
60
|
def benchmark_fit():
|
88
61
|
vari.fit(10000, var_draws=var_draws)
|
89
62
|
|
90
|
-
benchmark(benchmark_fit)
|
91
63
|
vari = vari.fit(20000, var_draws=var_draws)
|
64
|
+
benchmark(benchmark_fit)
|
92
65
|
|
93
66
|
params = vari.flows[0].transform_params()
|
94
67
|
assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
|
@@ -96,26 +69,8 @@ def test_affine(benchmark, var_draws):
|
|
96
69
|
).all()
|
97
70
|
|
98
71
|
|
99
|
-
@pytest.mark.parametrize("var_draws", [1,
|
72
|
+
@pytest.mark.parametrize("var_draws", [1, 100])
|
100
73
|
def test_flows(benchmark, var_draws):
|
101
|
-
# Construct model definition
|
102
|
-
class NormalDist(Model):
|
103
|
-
def __init__(self):
|
104
|
-
self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
|
105
|
-
self.constraints = {}
|
106
|
-
|
107
|
-
@eqx.filter_jit
|
108
|
-
def eval(self, data: dict):
|
109
|
-
# Get constrained parameters
|
110
|
-
params, target = self.constrain_params()
|
111
|
-
|
112
|
-
# Evaluate mu ~ N(10,1)
|
113
|
-
target += normal.logprob(
|
114
|
-
x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
|
115
|
-
).sum()
|
116
|
-
|
117
|
-
return target
|
118
|
-
|
119
74
|
# Construct model
|
120
75
|
model = NormalDist()
|
121
76
|
|
@@ -128,8 +83,9 @@ def test_flows(benchmark, var_draws):
|
|
128
83
|
def benchmark_fit():
|
129
84
|
vari.fit(10000, var_draws=var_draws)
|
130
85
|
|
131
|
-
benchmark(benchmark_fit)
|
132
86
|
vari = vari.fit(20000, var_draws=var_draws)
|
87
|
+
benchmark(benchmark_fit)
|
88
|
+
|
133
89
|
|
134
90
|
mean = vari.sample(1000).mean(0)
|
135
91
|
var = vari.sample(1000).var(0)
|
@@ -1,78 +0,0 @@
|
|
1
|
-
from abc import abstractmethod
|
2
|
-
from typing import Any, Dict, Generic, Tuple, TypeVar
|
3
|
-
|
4
|
-
import equinox as eqx
|
5
|
-
import jax.numpy as jnp
|
6
|
-
import jax.tree as jt
|
7
|
-
from jaxtyping import PyTree, Scalar
|
8
|
-
|
9
|
-
from bayinx.core.constraint import Constraint
|
10
|
-
from bayinx.core.parameter import Parameter
|
11
|
-
|
12
|
-
T = TypeVar('T', bound=PyTree)
|
13
|
-
class Model(eqx.Module, Generic[T]):
|
14
|
-
"""
|
15
|
-
An abstract base class used to define probabilistic models.
|
16
|
-
|
17
|
-
# Attributes
|
18
|
-
- `params`: A dictionary of parameters.
|
19
|
-
- `constraints`: A dictionary of constraints.
|
20
|
-
"""
|
21
|
-
|
22
|
-
params: Dict[str, Parameter[T]]
|
23
|
-
constraints: Dict[str, Constraint]
|
24
|
-
|
25
|
-
@abstractmethod
|
26
|
-
def eval(self, data: Any) -> Scalar:
|
27
|
-
pass
|
28
|
-
|
29
|
-
# Default filter specification
|
30
|
-
@property
|
31
|
-
@eqx.filter_jit
|
32
|
-
def filter_spec(self):
|
33
|
-
"""
|
34
|
-
Generates a filter specification to subset relevant parameters for the model.
|
35
|
-
"""
|
36
|
-
# Generate empty specification
|
37
|
-
filter_spec = jt.map(lambda _: False, self)
|
38
|
-
|
39
|
-
# Specify relevant parameters
|
40
|
-
filter_spec = eqx.tree_at(
|
41
|
-
lambda model: model.params,
|
42
|
-
filter_spec,
|
43
|
-
replace={key: param.filter_spec for key, param in self.params.items()}
|
44
|
-
)
|
45
|
-
|
46
|
-
return filter_spec
|
47
|
-
|
48
|
-
# Add constrain method
|
49
|
-
@eqx.filter_jit
|
50
|
-
def constrain_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
|
51
|
-
"""
|
52
|
-
Constrain `params` to the appropriate domain.
|
53
|
-
|
54
|
-
# Returns
|
55
|
-
A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
|
56
|
-
"""
|
57
|
-
t_params: Dict[str, Parameter[T]] = self.params
|
58
|
-
target: Scalar = jnp.array(0.0)
|
59
|
-
|
60
|
-
for par, map in self.constraints.items():
|
61
|
-
# Apply transformation
|
62
|
-
t_params[par], ladj = map.constrain(t_params[par])
|
63
|
-
|
64
|
-
# Adjust posterior density
|
65
|
-
target -= ladj
|
66
|
-
|
67
|
-
return t_params, target
|
68
|
-
|
69
|
-
# Add default transform method
|
70
|
-
@eqx.filter_jit
|
71
|
-
def transform_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
|
72
|
-
"""
|
73
|
-
Apply a custom transformation to `params` if needed.
|
74
|
-
|
75
|
-
# Returns
|
76
|
-
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
77
|
-
"""
|
78
|
-
return self.constrain_params()
|
bayinx-0.3.1/tests/__init__.py
DELETED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|