bayinx 0.3.1__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayinx-0.3.1 → bayinx-0.3.2}/PKG-INFO +1 -1
- {bayinx-0.3.1 → bayinx-0.3.2}/pyproject.toml +2 -2
- bayinx-0.3.2/src/bayinx/constraints/__init__.py +1 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/model.py +6 -6
- {bayinx-0.3.1 → bayinx-0.3.2}/tests/test_variational.py +16 -41
- bayinx-0.3.1/tests/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/.gitignore +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/README.md +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/constraints/lower.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/constraint.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/flow.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/parameter.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/variational.py +0 -0
- {bayinx-0.3.1/src/bayinx/constraints → bayinx-0.3.2/src/bayinx/dists}/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.3.1/src/bayinx/dists → bayinx-0.3.2/src/bayinx/dists/censored}/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/censored/gamma2/r.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/uniform.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/planar.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/radial.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/meanfield.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/standard.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/py.typed +0 -0
- {bayinx-0.3.1/src/bayinx/dists/censored → bayinx-0.3.2/tests}/__init__.py +0 -0
- {bayinx-0.3.1 → bayinx-0.3.2}/uv.lock +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "bayinx"
|
3
|
-
version = "0.3.1"
|
3
|
+
version = "0.3.2"
|
4
4
|
description = "Bayesian Inference with JAX"
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.12"
|
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
|
|
19
19
|
addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
|
20
20
|
|
21
21
|
[tool.bumpversion]
|
22
|
-
current_version = "0.3.1"
|
22
|
+
current_version = "0.3.2"
|
23
23
|
parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
|
24
24
|
serialize = ["{major}.{minor}.{patch}"]
|
25
25
|
search = "{current_version}"
|
@@ -0,0 +1 @@
|
|
1
|
+
from bayinx.constraints.lower import Lower as Lower
|
@@ -9,8 +9,8 @@ from jaxtyping import PyTree, Scalar
|
|
9
9
|
from bayinx.core.constraint import Constraint
|
10
10
|
from bayinx.core.parameter import Parameter
|
11
11
|
|
12
|
-
|
13
|
-
class Model(eqx.Module, Generic[T]):
|
12
|
+
P = TypeVar('P', bound=Dict[str, Parameter[PyTree]])
|
13
|
+
class Model(eqx.Module, Generic[P]):
|
14
14
|
"""
|
15
15
|
An abstract base class used to define probabilistic models.
|
16
16
|
|
@@ -19,7 +19,7 @@ class Model(eqx.Module, Generic[T]):
|
|
19
19
|
- `constraints`: A dictionary of constraints.
|
20
20
|
"""
|
21
21
|
|
22
|
-
params:
|
22
|
+
params: P
|
23
23
|
constraints: Dict[str, Constraint]
|
24
24
|
|
25
25
|
@abstractmethod
|
@@ -47,14 +47,14 @@ class Model(eqx.Module, Generic[T]):
|
|
47
47
|
|
48
48
|
# Add constrain method
|
49
49
|
@eqx.filter_jit
|
50
|
-
def constrain_params(self) -> Tuple[
|
50
|
+
def constrain_params(self) -> Tuple[P, Scalar]:
|
51
51
|
"""
|
52
52
|
Constrain `params` to the appropriate domain.
|
53
53
|
|
54
54
|
# Returns
|
55
55
|
A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
|
56
56
|
"""
|
57
|
-
t_params:
|
57
|
+
t_params: P = self.params
|
58
58
|
target: Scalar = jnp.array(0.0)
|
59
59
|
|
60
60
|
for par, map in self.constraints.items():
|
@@ -68,7 +68,7 @@ class Model(eqx.Module, Generic[T]):
|
|
68
68
|
|
69
69
|
# Add default transform method
|
70
70
|
@eqx.filter_jit
|
71
|
-
def transform_params(self) -> Tuple[
|
71
|
+
def transform_params(self) -> Tuple[P, Scalar]:
|
72
72
|
"""
|
73
73
|
Apply a custom transformation to `params` if needed.
|
74
74
|
|
@@ -1,4 +1,3 @@
|
|
1
|
-
|
2
1
|
from typing import Dict
|
3
2
|
|
4
3
|
import equinox as eqx
|
@@ -12,30 +11,26 @@ from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
|
|
12
11
|
from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
|
13
12
|
|
14
13
|
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
class NormalDist(Model[Array]):
|
20
|
-
params: Dict[str, Parameter[Array]]
|
14
|
+
class NormalDist(Model[Dict[str, Parameter[Array]]]):
|
15
|
+
def __init__(self):
|
16
|
+
self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
|
17
|
+
self.constraints = {}
|
21
18
|
|
22
|
-
|
23
|
-
|
24
|
-
|
19
|
+
@eqx.filter_jit
|
20
|
+
def eval(self, data = None):
|
21
|
+
# Get constrained parameters
|
22
|
+
params, target = self.constrain_params()
|
25
23
|
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
24
|
+
# Evaluate mu ~ N(10,1)
|
25
|
+
target += normal.logprob(
|
26
|
+
x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
|
27
|
+
).sum()
|
30
28
|
|
31
|
-
|
32
|
-
target += normal.logprob(
|
33
|
-
x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
|
34
|
-
).sum()
|
35
|
-
|
36
|
-
# Evaluate mu ~ N(10,1)
|
37
|
-
return target
|
29
|
+
return target
|
38
30
|
|
31
|
+
# Tests ----
|
32
|
+
@pytest.mark.parametrize("var_draws", [1, 10, 100])
|
33
|
+
def test_meanfield(benchmark, var_draws):
|
39
34
|
# Construct model
|
40
35
|
model = NormalDist()
|
41
36
|
|
@@ -57,26 +52,6 @@ def test_meanfield(benchmark, var_draws):
|
|
57
52
|
|
58
53
|
@pytest.mark.parametrize("var_draws", [1, 10, 100])
|
59
54
|
def test_affine(benchmark, var_draws):
|
60
|
-
# Construct model definition
|
61
|
-
class NormalDist(Model):
|
62
|
-
|
63
|
-
def __init__(self):
|
64
|
-
self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
|
65
|
-
self.constraints = {}
|
66
|
-
|
67
|
-
@eqx.filter_jit
|
68
|
-
def eval(self, data: dict):
|
69
|
-
# Get constrained parameters
|
70
|
-
params, target = self.constrain_params()
|
71
|
-
|
72
|
-
# Evaluate mu ~ N(10,1)
|
73
|
-
target += normal.logprob(
|
74
|
-
x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
|
75
|
-
).sum()
|
76
|
-
|
77
|
-
# Evaluate mu ~ N(10,1)
|
78
|
-
return target
|
79
|
-
|
80
55
|
# Construct model
|
81
56
|
model = NormalDist()
|
82
57
|
|
bayinx-0.3.1/tests/__init__.py
DELETED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|