bayinx 0.2.24__tar.gz → 0.2.26__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {bayinx-0.2.24 → bayinx-0.2.26}/PKG-INFO +1 -1
  2. {bayinx-0.2.24 → bayinx-0.2.26}/pyproject.toml +2 -2
  3. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/flow.py +1 -3
  4. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/model.py +1 -2
  5. bayinx-0.2.26/src/bayinx/mhx/__init__.py +1 -0
  6. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/flows/fullaffine.py +16 -1
  7. bayinx-0.2.24/src/bayinx/core/utils.py +0 -54
  8. {bayinx-0.2.24 → bayinx-0.2.26}/.github/workflows/release_and_publish.yml +0 -0
  9. {bayinx-0.2.24 → bayinx-0.2.26}/.gitignore +0 -0
  10. {bayinx-0.2.24 → bayinx-0.2.26}/README.md +0 -0
  11. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/__init__.py +0 -0
  12. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/__init__.py +0 -0
  13. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/constraints.py +0 -0
  14. /bayinx-0.2.24/src/bayinx/mhx/__init__.py → /bayinx-0.2.26/src/bayinx/core/utils.py +0 -0
  15. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/variational.py +0 -0
  16. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/__init__.py +0 -0
  17. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/bernoulli.py +0 -0
  18. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/binomial.py +0 -0
  19. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/gamma.py +0 -0
  20. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/gamma2.py +0 -0
  21. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/normal.py +0 -0
  22. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/dists/uniform.py +0 -0
  23. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/__init__.py +0 -0
  24. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  25. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  26. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  27. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  28. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/meanfield.py +0 -0
  29. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
  30. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/standard.py +0 -0
  31. {bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/py.typed +0 -0
  32. {bayinx-0.2.24 → bayinx-0.2.26}/tests/__init__.py +0 -0
  33. {bayinx-0.2.24 → bayinx-0.2.26}/tests/test_variational.py +0 -0
  34. {bayinx-0.2.24 → bayinx-0.2.26}/uv.lock +0 -0
{bayinx-0.2.24 → bayinx-0.2.26}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayinx
-Version: 0.2.24
+Version: 0.2.26
 Summary: Bayesian Inference with JAX
 Requires-Python: >=3.12
 Requires-Dist: equinox>=0.11.12
{bayinx-0.2.24 → bayinx-0.2.26}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.2.24"
+version = "0.2.26"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
 
 [tool.bumpversion]
-current_version = "0.2.24"
+current_version = "0.2.26"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"
{bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/flow.py
@@ -5,10 +5,8 @@ import equinox as eqx
 import jax.tree_util as jtu
 from jaxtyping import Array, Float
 
-from bayinx.core.utils import __MyMeta
 
-
-class Flow(eqx.Module, metaclass=__MyMeta):
+class Flow(eqx.Module):
     """
     A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
 
{bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/core/model.py
@@ -7,10 +7,9 @@ import jax.tree_util as jtu
 from jaxtyping import Array, Scalar
 
 from bayinx.core.constraints import Constraint
-from bayinx.core.utils import __MyMeta
 
 
-class Model(eqx.Module, metaclass=__MyMeta):
+class Model(eqx.Module):
     """
     A superclass used to define probabilistic models.
 
bayinx-0.2.26/src/bayinx/mhx/__init__.py
@@ -0,0 +1 @@
+
{bayinx-0.2.24 → bayinx-0.2.26}/src/bayinx/mhx/vi/flows/fullaffine.py
@@ -33,7 +33,22 @@ class FullAffine(Flow):
             "scale": jnp.zeros((dim, dim)),
         }
 
-        self.constraints = {"scale": lambda m: jnp.tril(jnp.exp(m))}
+        self.constraints = {"scale": lambda m: jnp.tril(m)}
+
+    def transform_pars(self):
+        # Get constrained parameters
+        params = self.constrain_pars()
+
+        # Extract diagonal and apply exponential
+        diag: Array = jnp.exp(jnp.diag(params['scale']))
+
+        # Fill diagonal
+        params['scale'] = jnp.fill_diagonal(params['scale'], diag)
+
+
+        return params
+
+
 
     @eqx.filter_jit
     def forward(self, draws: Array) -> Array:
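
A note on the FullAffine hunk above: the constraint no longer exponentiates the whole scale matrix; it only keeps the lower triangle, and the new transform_pars method exponentiates just the diagonal. Below is a minimal standalone sketch of what that parameterization produces (illustrative only; build_scale is a hypothetical helper, not part of the package):

    import jax.numpy as jnp

    def build_scale(raw):
        # Constraint step: keep only the lower triangle of the unconstrained matrix.
        tril = jnp.tril(raw)
        # transform_pars step: exponentiate the diagonal so it is strictly positive.
        diag = jnp.exp(jnp.diag(tril))
        # Write the positive diagonal back; inplace=False returns a new array under JAX.
        return jnp.fill_diagonal(tril, diag, inplace=False)

    scale = build_scale(jnp.zeros((3, 3)))  # identity diagonal (exp(0) == 1), zeros elsewhere

A lower-triangular scale with a positive diagonal is always invertible, and its log-determinant is simply the sum of the unconstrained diagonal entries.
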
bayinx-0.2.24/src/bayinx/core/utils.py
@@ -1,54 +0,0 @@
-from typing import Callable, Dict
-
-import equinox as eqx
-from jaxtyping import Array
-
-
-class __MyMeta(type(eqx.Module)):
-    """
-    Metaclass to ensure attribute types are respected.
-    """
-
-    def __call__(cls, *args, **kwargs):
-        obj = super().__call__(*args, **kwargs)
-
-        # Check parameters are a Dict of JAX Arrays
-        if not isinstance(obj.params, Dict):
-            raise ValueError(
-                f"Model {cls.__name__} must initialize 'params' as a dictionary."
-            )
-
-        for key, value in obj.params.items():
-            if not isinstance(value, Array):
-                raise TypeError(f"Parameter '{key}' must be a JAX Array.")
-
-        # Check constraints are a Dict of functions
-        if not isinstance(obj.constraints, Dict):
-            raise ValueError(
-                f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
-            )
-
-        for key, value in obj.constraints.items():
-            if not isinstance(value, Callable):
-                raise TypeError(f"Constraint for parameter '{key}' must be a function.")
-
-        # Check that the constrain method returns a dict equivalent to `params`
-        t_params: Dict[str, Array] = obj.constrain_pars()
-
-        if not isinstance(t_params, Dict):
-            raise ValueError(
-                f"The 'constrain' method of {cls.__name__} must return a Dict of JAX Arrays."
-            )
-
-        for key, value in t_params.items():
-            if not isinstance(value, Array):
-                raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
-
-            if not value.shape == obj.params[key].shape:
-                raise ValueError(
-                    f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
-                )
-
-        # Check transform_pars
-
-        return obj
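
The deleted __MyMeta metaclass above validated params and constraints at construction time. Purely as an illustration (an assumption, not something shown elsewhere in this diff), equivalent checks could live directly on an eqx.Module subclass via Equinox's __check_init__ hook, which runs after initialisation:

    from typing import Callable, Dict

    import equinox as eqx
    from jaxtyping import Array


    class ValidatedModule(eqx.Module):
        # Hypothetical sketch: reproduce the removed checks without a metaclass.
        params: Dict[str, Array]
        constraints: Dict[str, Callable]

        def __check_init__(self):
            # Parameters must be a dict of JAX arrays.
            if not isinstance(self.params, dict):
                raise ValueError("'params' must be a dictionary.")
            for key, value in self.params.items():
                if not isinstance(value, Array):
                    raise TypeError(f"Parameter '{key}' must be a JAX Array.")
            # Constraints must be a dict of callables.
            if not isinstance(self.constraints, dict):
                raise ValueError("'constraints' must be a dictionary.")
            for key, fn in self.constraints.items():
                if not callable(fn):
                    raise TypeError(f"Constraint for parameter '{key}' must be callable.")
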