bayinx 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {bayinx-0.3.2 → bayinx-0.3.4}/PKG-INFO +1 -1
  2. {bayinx-0.3.2 → bayinx-0.3.4}/pyproject.toml +2 -2
  3. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/__init__.py +1 -0
  4. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/core/flow.py +1 -1
  5. bayinx-0.3.4/src/bayinx/core/model.py +102 -0
  6. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/core/parameter.py +3 -0
  7. bayinx-0.3.4/src/bayinx/dists/censored/gamma2/__init__.py +1 -0
  8. {bayinx-0.3.2 → bayinx-0.3.4}/tests/test_variational.py +17 -36
  9. {bayinx-0.3.2 → bayinx-0.3.4}/uv.lock +1 -1
  10. bayinx-0.3.2/src/bayinx/core/model.py +0 -78
  11. {bayinx-0.3.2 → bayinx-0.3.4}/.github/workflows/release_and_publish.yml +0 -0
  12. {bayinx-0.3.2 → bayinx-0.3.4}/.gitignore +0 -0
  13. {bayinx-0.3.2 → bayinx-0.3.4}/README.md +0 -0
  14. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/constraints/__init__.py +0 -0
  15. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/constraints/lower.py +0 -0
  16. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/core/__init__.py +0 -0
  17. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/core/constraint.py +0 -0
  18. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/core/variational.py +0 -0
  19. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/__init__.py +0 -0
  20. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/bernoulli.py +0 -0
  21. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/censored/__init__.py +0 -0
  22. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  23. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/gamma2.py +0 -0
  24. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/normal.py +0 -0
  25. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/dists/uniform.py +0 -0
  26. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/__init__.py +0 -0
  27. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/__init__.py +0 -0
  28. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  29. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  30. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  31. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  32. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  33. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/meanfield.py +0 -0
  34. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
  35. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/mhx/vi/standard.py +0 -0
  36. {bayinx-0.3.2 → bayinx-0.3.4}/src/bayinx/py.typed +0 -0
  37. {bayinx-0.3.2 → bayinx-0.3.4}/tests/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.3.2"
3
+ version = "0.3.4"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
19
19
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
20
20
 
21
21
  [tool.bumpversion]
22
- current_version = "0.3.2"
22
+ current_version = "0.3.4"
23
23
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
24
24
  serialize = ["{major}.{minor}.{patch}"]
25
25
  search = "{current_version}"
@@ -1,2 +1,3 @@
1
1
  from bayinx.core import Model as Model
2
2
  from bayinx.core import Parameter as Parameter
3
+ from bayinx.core.model import constrain as constrain
@@ -11,7 +11,7 @@ class Flow(eqx.Module):
11
11
  An abstract base class for a flow(of a normalizing flow).
12
12
 
13
13
  # Attributes
14
- - `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
14
+ - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
15
15
  - `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
16
16
  """
17
17
 
@@ -0,0 +1,102 @@
1
+ from abc import abstractmethod
2
+ from dataclasses import field, fields
3
+ from typing import Any, Self, Tuple
4
+
5
+ import equinox as eqx
6
+ import jax.numpy as jnp
7
+ import jax.tree as jt
8
+ from jaxtyping import Scalar
9
+
10
+ from bayinx.core.constraint import Constraint
11
+ from bayinx.core.parameter import Parameter
12
+
13
+
14
def constrain(constraint: Constraint):
    """
    Declare a dataclass field that carries `constraint` in its metadata.

    Use it as the right-hand side of a `Model` field declaration; the
    constraint is stored under the metadata key `'constraint'`, where
    `Model.constrain_params` looks for it.
    """
    constraint_metadata = {'constraint': constraint}
    return field(metadata=constraint_metadata)
17
+
18
+
19
class Model(eqx.Module):
    """
    Abstract base class for probabilistic models.

    Subclasses declare each parameter as a dataclass field annotated with
    `Parameter`. A parameter field can carry a constraint by assigning it
    `constrain(SomeConstraint)`, which stores the constraint in the field's
    metadata under the key `'constraint'`.
    """

    @abstractmethod
    def eval(self, data: Any) -> Scalar:
        """Evaluate the model's target on `data` (presumably the unnormalized log posterior — implemented by subclasses)."""
        pass

    # Default filter specification
    @property
    @eqx.filter_jit
    def filter_spec(self) -> Self:
        """
        Build a filter specification that selects the model's parameters.

        Every leaf starts out as `False`; each field holding a `Parameter` is
        then replaced by that parameter's own `filter_spec`.
        """
        spec: Self = jt.map(lambda _: False, self)

        # Pair each dataclass field name with its current value.
        named_values = [(f.name, getattr(self, f.name)) for f in fields(self)]

        for name, value in named_values:
            if not isinstance(value, Parameter):
                continue

            # `name` is bound as a default so the selector never depends on the loop variable.
            spec = eqx.tree_at(
                lambda model, name=name: getattr(model, name),
                spec,
                replace=value.filter_spec,
            )

        return spec

    @eqx.filter_jit
    def constrain_params(self) -> Tuple[Self, Scalar]:
        """
        Map constrained parameters onto their domains.

        For every field whose value is a `Parameter` and whose metadata has a
        `'constraint'` entry, that constraint's `constrain` method is applied
        and the field is replaced with its first return value.

        # Returns
        A `Model` with constrained parameters, and the sum of the second values
        returned by each `constrain` call (the adjustment to the posterior).
        """
        model: Self = self
        adjustment: Scalar = jnp.array(0.0)

        for f in fields(self):
            value = getattr(self, f.name)

            # Only parameters that declare a constraint are transformed.
            if not isinstance(value, Parameter) or 'constraint' not in f.metadata:
                continue

            constrained_value, laj = f.metadata['constraint'].constrain(value)

            model = eqx.tree_at(
                lambda m, name=f.name: getattr(m, name),
                model,
                replace=constrained_value,
            )

            # Accumulate the posterior adjustment reported by the constraint.
            adjustment += laj

        return model, adjustment

    @eqx.filter_jit
    def transform_params(self) -> Tuple[Self, Scalar]:
        """
        Apply a custom transformation to the parameters, if needed (by default, constrain them).

        # Returns
        A transformed `Model` object and the adjustment to the posterior.
        """
        return self.constrain_params()
@@ -21,6 +21,9 @@ class Parameter(eqx.Module, Generic[T]):
21
21
  # Insert parameter values
22
22
  self.vals = values
23
23
 
24
def __call__(self) -> T:
    """Calling a parameter yields its stored values (`self.vals`)."""
    values = self.vals
    return values
26
+
24
27
  # Default filter specification
25
28
  @property
26
29
  @eqx.filter_jit
@@ -0,0 +1 @@
1
+ from . import r as r
@@ -1,6 +1,5 @@
1
1
  from typing import Dict
2
2
 
3
- import equinox as eqx
4
3
  import jax.numpy as jnp
5
4
  import pytest
6
5
  from jaxtyping import Array
@@ -11,25 +10,24 @@ from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
11
10
  from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
12
11
 
13
12
 
14
class NormalDist(Model):
    """Toy model with a single two-component parameter `x` and a Normal(10, 1) prior."""

    x: Parameter[Array]

    def __init__(self):
        self.x = Parameter(jnp.array([0.0, 0.0]))

    def eval(self, data: Dict[str, Array]):
        # Constrain parameters into a separate local instead of rebinding `self`.
        model, log_density = self.constrain_params()

        # Evaluate x ~ Normal(10.0, 1.0), summed over the components of x.
        log_density += jnp.sum(normal.logprob(model.x(), jnp.array(10.0), jnp.array(1.0)))

        return log_density
30
27
 
28
+
31
29
  # Tests ----
32
- @pytest.mark.parametrize("var_draws", [1, 10, 100])
30
+ @pytest.mark.parametrize("var_draws", [1, 100])
33
31
  def test_meanfield(benchmark, var_draws):
34
32
  # Construct model
35
33
  model = NormalDist()
@@ -41,8 +39,8 @@ def test_meanfield(benchmark, var_draws):
41
39
  def benchmark_fit():
42
40
  vari.fit(10000, var_draws=var_draws)
43
41
 
44
- benchmark(benchmark_fit)
45
42
  vari = vari.fit(20000, var_draws=var_draws)
43
+ benchmark(benchmark_fit)
46
44
 
47
45
  # Assert parameters are roughly correct
48
46
  assert all(abs(10.0 - vari.var_params["mean"]) < 0.1) and all(
@@ -50,7 +48,7 @@ def test_meanfield(benchmark, var_draws):
50
48
  )
51
49
 
52
50
 
53
- @pytest.mark.parametrize("var_draws", [1, 10, 100])
51
+ @pytest.mark.parametrize("var_draws", [1, 100])
54
52
  def test_affine(benchmark, var_draws):
55
53
  # Construct model
56
54
  model = NormalDist()
@@ -62,8 +60,8 @@ def test_affine(benchmark, var_draws):
62
60
  def benchmark_fit():
63
61
  vari.fit(10000, var_draws=var_draws)
64
62
 
65
- benchmark(benchmark_fit)
66
63
  vari = vari.fit(20000, var_draws=var_draws)
64
+ benchmark(benchmark_fit)
67
65
 
68
66
  params = vari.flows[0].transform_params()
69
67
  assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
@@ -71,26 +69,8 @@ def test_affine(benchmark, var_draws):
71
69
  ).all()
72
70
 
73
71
 
74
- @pytest.mark.parametrize("var_draws", [1, 10, 100])
72
+ @pytest.mark.parametrize("var_draws", [1, 100])
75
73
  def test_flows(benchmark, var_draws):
76
- # Construct model definition
77
- class NormalDist(Model):
78
- def __init__(self):
79
- self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
80
- self.constraints = {}
81
-
82
- @eqx.filter_jit
83
- def eval(self, data: dict):
84
- # Get constrained parameters
85
- params, target = self.constrain_params()
86
-
87
- # Evaluate mu ~ N(10,1)
88
- target += normal.logprob(
89
- x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
90
- ).sum()
91
-
92
- return target
93
-
94
74
  # Construct model
95
75
  model = NormalDist()
96
76
 
@@ -103,8 +83,9 @@ def test_flows(benchmark, var_draws):
103
83
  def benchmark_fit():
104
84
  vari.fit(10000, var_draws=var_draws)
105
85
 
106
- benchmark(benchmark_fit)
107
86
  vari = vari.fit(20000, var_draws=var_draws)
87
+ benchmark(benchmark_fit)
88
+
108
89
 
109
90
  mean = vari.sample(1000).mean(0)
110
91
  var = vari.sample(1000).var(0)
@@ -17,7 +17,7 @@ wheels = [
17
17
 
18
18
  [[package]]
19
19
  name = "bayinx"
20
- version = "0.2.11"
20
+ version = "0.3.3"
21
21
  source = { editable = "." }
22
22
  dependencies = [
23
23
  { name = "equinox" },
@@ -1,78 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any, Dict, Generic, Tuple, TypeVar
3
-
4
- import equinox as eqx
5
- import jax.numpy as jnp
6
- import jax.tree as jt
7
- from jaxtyping import PyTree, Scalar
8
-
9
- from bayinx.core.constraint import Constraint
10
- from bayinx.core.parameter import Parameter
11
-
12
P = TypeVar('P', bound=Dict[str, Parameter[PyTree]])


class Model(eqx.Module, Generic[P]):
    """
    An abstract base class used to define probabilistic models.

    # Attributes
    - `params`: A dictionary of parameters.
    - `constraints`: A dictionary of constraints, keyed by the name of the parameter they apply to.
    """

    params: P
    constraints: Dict[str, Constraint]

    @abstractmethod
    def eval(self, data: Any) -> Scalar:
        """Evaluate the model's target on `data`; implemented by subclasses."""
        pass

    # Default filter specification
    @property
    @eqx.filter_jit
    def filter_spec(self):
        """
        Generates a filter specification to subset relevant parameters for the model.
        """
        # Generate empty specification
        filter_spec = jt.map(lambda _: False, self)

        # Specify relevant parameters
        filter_spec = eqx.tree_at(
            lambda model: model.params,
            filter_spec,
            replace={key: param.filter_spec for key, param in self.params.items()}
        )

        return filter_spec

    # Add constrain method
    @eqx.filter_jit
    def constrain_params(self) -> Tuple[P, Scalar]:
        """
        Constrain `params` to the appropriate domain.

        The model's own `params` dictionary is left untouched; a shallow copy
        is returned with the constrained entries replaced.

        # Returns
        A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
        """
        # Copy the dict: assigning into `self.params` directly would mutate the
        # (supposedly immutable) module, and a repeated non-jitted call would
        # constrain already-constrained values.
        t_params: P = dict(self.params)
        target: Scalar = jnp.array(0.0)

        for par, constraint in self.constraints.items():
            # Apply transformation
            t_params[par], ladj = constraint.constrain(t_params[par])

            # Adjust posterior density
            # NOTE(review): the 0.3.4 replacement of this class adds the
            # adjustment instead of subtracting it; confirm which sign the
            # value returned by `Constraint.constrain` requires.
            target -= ladj

        return t_params, target

    # Add default transform method
    @eqx.filter_jit
    def transform_params(self) -> Tuple[P, Scalar]:
        """
        Apply a custom transformation to `params` if needed (defaults to constraining them).

        # Returns
        A dictionary of transformed parameters and the adjustment to the posterior density.
        """
        return self.constrain_params()
File without changes
File without changes
File without changes
File without changes