bayinx 0.3.1.tar.gz → 0.3.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {bayinx-0.3.1 → bayinx-0.3.2}/PKG-INFO +1 -1
  2. {bayinx-0.3.1 → bayinx-0.3.2}/pyproject.toml +2 -2
  3. bayinx-0.3.2/src/bayinx/constraints/__init__.py +1 -0
  4. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/model.py +6 -6
  5. {bayinx-0.3.1 → bayinx-0.3.2}/tests/test_variational.py +16 -41
  6. bayinx-0.3.1/tests/__init__.py +0 -0
  7. {bayinx-0.3.1 → bayinx-0.3.2}/.github/workflows/release_and_publish.yml +0 -0
  8. {bayinx-0.3.1 → bayinx-0.3.2}/.gitignore +0 -0
  9. {bayinx-0.3.1 → bayinx-0.3.2}/README.md +0 -0
  10. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/__init__.py +0 -0
  11. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/constraints/lower.py +0 -0
  12. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/__init__.py +0 -0
  13. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/constraint.py +0 -0
  14. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/flow.py +0 -0
  15. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/parameter.py +0 -0
  16. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/core/variational.py +0 -0
  17. {bayinx-0.3.1/src/bayinx/constraints → bayinx-0.3.2/src/bayinx/dists}/__init__.py +0 -0
  18. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/bernoulli.py +0 -0
  19. {bayinx-0.3.1/src/bayinx/dists → bayinx-0.3.2/src/bayinx/dists/censored}/__init__.py +0 -0
  20. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  21. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/gamma2.py +0 -0
  22. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/normal.py +0 -0
  23. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/dists/uniform.py +0 -0
  24. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/__init__.py +0 -0
  25. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/__init__.py +0 -0
  26. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  27. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  28. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  29. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  30. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  31. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/meanfield.py +0 -0
  32. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
  33. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/mhx/vi/standard.py +0 -0
  34. {bayinx-0.3.1 → bayinx-0.3.2}/src/bayinx/py.typed +0 -0
  35. {bayinx-0.3.1/src/bayinx/dists/censored → bayinx-0.3.2/tests}/__init__.py +0 -0
  36. {bayinx-0.3.1 → bayinx-0.3.2}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.3.1"
3
+ version = "0.3.2"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
19
19
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
20
20
 
21
21
  [tool.bumpversion]
22
- current_version = "0.3.1"
22
+ current_version = "0.3.2"
23
23
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
24
24
  serialize = ["{major}.{minor}.{patch}"]
25
25
  search = "{current_version}"
@@ -0,0 +1 @@
1
+ from bayinx.constraints.lower import Lower as Lower
@@ -9,8 +9,8 @@ from jaxtyping import PyTree, Scalar
9
9
  from bayinx.core.constraint import Constraint
10
10
  from bayinx.core.parameter import Parameter
11
11
 
12
- T = TypeVar('T', bound=PyTree)
13
- class Model(eqx.Module, Generic[T]):
12
+ P = TypeVar('P', bound=Dict[str, Parameter[PyTree]])
13
+ class Model(eqx.Module, Generic[P]):
14
14
  """
15
15
  An abstract base class used to define probabilistic models.
16
16
 
@@ -19,7 +19,7 @@ class Model(eqx.Module, Generic[T]):
19
19
  - `constraints`: A dictionary of constraints.
20
20
  """
21
21
 
22
- params: Dict[str, Parameter[T]]
22
+ params: P
23
23
  constraints: Dict[str, Constraint]
24
24
 
25
25
  @abstractmethod
@@ -47,14 +47,14 @@ class Model(eqx.Module, Generic[T]):
47
47
 
48
48
  # Add constrain method
49
49
  @eqx.filter_jit
50
- def constrain_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
50
+ def constrain_params(self) -> Tuple[P, Scalar]:
51
51
  """
52
52
  Constrain `params` to the appropriate domain.
53
53
 
54
54
  # Returns
55
55
  A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
56
56
  """
57
- t_params: Dict[str, Parameter[T]] = self.params
57
+ t_params: P = self.params
58
58
  target: Scalar = jnp.array(0.0)
59
59
 
60
60
  for par, map in self.constraints.items():
@@ -68,7 +68,7 @@ class Model(eqx.Module, Generic[T]):
68
68
 
69
69
  # Add default transform method
70
70
  @eqx.filter_jit
71
- def transform_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
71
+ def transform_params(self) -> Tuple[P, Scalar]:
72
72
  """
73
73
  Apply a custom transformation to `params` if needed.
74
74
 
@@ -1,4 +1,3 @@
1
-
2
1
  from typing import Dict
3
2
 
4
3
  import equinox as eqx
@@ -12,30 +11,26 @@ from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
12
11
  from bayinx.mhx.vi.flows import FullAffine, Planar, Radial
13
12
 
14
13
 
15
- # Tests ----
16
- @pytest.mark.parametrize("var_draws", [1, 10, 100])
17
- def test_meanfield(benchmark, var_draws):
18
- # Construct model definition
19
- class NormalDist(Model[Array]):
20
- params: Dict[str, Parameter[Array]]
14
+ class NormalDist(Model[Dict[str, Parameter[Array]]]):
15
+ def __init__(self):
16
+ self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
17
+ self.constraints = {}
21
18
 
22
- def __init__(self):
23
- self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
24
- self.constraints = {}
19
+ @eqx.filter_jit
20
+ def eval(self, data = None):
21
+ # Get constrained parameters
22
+ params, target = self.constrain_params()
25
23
 
26
- @eqx.filter_jit
27
- def eval(self, data = None):
28
- # Get constrained parameters
29
- params, target = self.constrain_params()
24
+ # Evaluate mu ~ N(10,1)
25
+ target += normal.logprob(
26
+ x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
27
+ ).sum()
30
28
 
31
- # Evaluate mu ~ N(10,1)
32
- target += normal.logprob(
33
- x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
34
- ).sum()
35
-
36
- # Evaluate mu ~ N(10,1)
37
- return target
29
+ return target
38
30
 
31
+ # Tests ----
32
+ @pytest.mark.parametrize("var_draws", [1, 10, 100])
33
+ def test_meanfield(benchmark, var_draws):
39
34
  # Construct model
40
35
  model = NormalDist()
41
36
 
@@ -57,26 +52,6 @@ def test_meanfield(benchmark, var_draws):
57
52
 
58
53
  @pytest.mark.parametrize("var_draws", [1, 10, 100])
59
54
  def test_affine(benchmark, var_draws):
60
- # Construct model definition
61
- class NormalDist(Model):
62
-
63
- def __init__(self):
64
- self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
65
- self.constraints = {}
66
-
67
- @eqx.filter_jit
68
- def eval(self, data: dict):
69
- # Get constrained parameters
70
- params, target = self.constrain_params()
71
-
72
- # Evaluate mu ~ N(10,1)
73
- target += normal.logprob(
74
- x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
75
- ).sum()
76
-
77
- # Evaluate mu ~ N(10,1)
78
- return target
79
-
80
55
  # Construct model
81
56
  model = NormalDist()
82
57
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes