bayinx 0.2.33__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {bayinx-0.2.33 → bayinx-0.3.2}/PKG-INFO +1 -1
  2. {bayinx-0.2.33 → bayinx-0.3.2}/pyproject.toml +2 -2
  3. bayinx-0.3.2/src/bayinx/__init__.py +2 -0
  4. bayinx-0.3.2/src/bayinx/constraints/__init__.py +1 -0
  5. bayinx-0.3.2/src/bayinx/constraints/lower.py +51 -0
  6. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/__init__.py +1 -0
  7. bayinx-0.3.2/src/bayinx/core/constraint.py +28 -0
  8. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/flow.py +6 -4
  9. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/model.py +17 -13
  10. bayinx-0.3.2/src/bayinx/core/parameter.py +41 -0
  11. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/variational.py +2 -2
  12. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/fullaffine.py +2 -2
  13. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/planar.py +2 -2
  14. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/radial.py +2 -2
  15. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/meanfield.py +19 -17
  16. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/normalizing_flow.py +6 -4
  17. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/standard.py +3 -3
  18. {bayinx-0.2.33 → bayinx-0.3.2}/tests/test_variational.py +22 -52
  19. bayinx-0.2.33/src/bayinx/__init__.py +0 -1
  20. bayinx-0.2.33/src/bayinx/constraints/lower.py +0 -37
  21. bayinx-0.2.33/src/bayinx/core/constraint.py +0 -26
  22. bayinx-0.2.33/tests/__init__.py +0 -0
  23. {bayinx-0.2.33 → bayinx-0.3.2}/.github/workflows/release_and_publish.yml +0 -0
  24. {bayinx-0.2.33 → bayinx-0.3.2}/.gitignore +0 -0
  25. {bayinx-0.2.33 → bayinx-0.3.2}/README.md +0 -0
  26. {bayinx-0.2.33/src/bayinx/constraints → bayinx-0.3.2/src/bayinx/dists}/__init__.py +0 -0
  27. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/dists/bernoulli.py +0 -0
  28. {bayinx-0.2.33/src/bayinx/dists → bayinx-0.3.2/src/bayinx/dists/censored}/__init__.py +0 -0
  29. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  30. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/dists/gamma2.py +0 -0
  31. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/dists/normal.py +0 -0
  32. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/dists/uniform.py +0 -0
  33. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/__init__.py +0 -0
  34. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/__init__.py +0 -0
  35. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  36. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  37. {bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/py.typed +0 -0
  38. {bayinx-0.2.33/src/bayinx/dists/censored → bayinx-0.3.2/tests}/__init__.py +0 -0
  39. {bayinx-0.2.33 → bayinx-0.3.2}/uv.lock +0 -0
{bayinx-0.2.33 → bayinx-0.3.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayinx
- Version: 0.2.33
+ Version: 0.3.2
  Summary: Bayesian Inference with JAX
  Requires-Python: >=3.12
  Requires-Dist: equinox>=0.11.12
{bayinx-0.2.33 → bayinx-0.3.2}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "bayinx"
- version = "0.2.33"
+ version = "0.3.2"
  description = "Bayesian Inference with JAX"
  readme = "README.md"
  requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"

  [tool.bumpversion]
- current_version = "0.2.33"
+ current_version = "0.3.2"
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
  serialize = ["{major}.{minor}.{patch}"]
  search = "{current_version}"
bayinx-0.3.2/src/bayinx/__init__.py
@@ -0,0 +1,2 @@
+ from bayinx.core import Model as Model
+ from bayinx.core import Parameter as Parameter
bayinx-0.3.2/src/bayinx/constraints/__init__.py
@@ -0,0 +1 @@
+ from bayinx.constraints.lower import Lower as Lower
bayinx-0.3.2/src/bayinx/constraints/lower.py
@@ -0,0 +1,51 @@
+ from typing import Tuple
+
+ import equinox as eqx
+ import jax.numpy as jnp
+ import jax.tree as jt
+ from jaxtyping import PyTree, Scalar, ScalarLike
+
+ from bayinx.core.constraint import Constraint
+ from bayinx.core.parameter import Parameter
+
+
+ class Lower(Constraint):
+     """
+     Enforces a lower bound on the parameter.
+     """
+
+     lb: Scalar
+
+     def __init__(self, lb: ScalarLike):
+         self.lb = jnp.array(lb)
+
+     @eqx.filter_jit
+     def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+         """
+         Enforces a lower bound on the parameter and adjusts the posterior density.
+
+         # Parameters
+         - `x`: The unconstrained `Parameter`.
+
+         # Returns
+         A tuple containing:
+         - A modified `Parameter` with relevant leaves satisfying the constraint.
+         - A scalar Array representing the log-absolute-Jacobian of the transformation.
+         """
+         # Extract relevant filter specification
+         filter_spec = x.filter_spec
+
+         # Extract relevant parameters (all Array)
+         dyn_params, static_params = eqx.partition(x, filter_spec)
+
+         # Compute density adjustment
+         laj: PyTree = jt.map(jnp.sum, dyn_params)  # pyright: ignore
+         laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
+
+         # Compute transformation
+         dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
+
+         # Combine into full parameter object
+         x = eqx.combine(dyn_params, static_params)
+
+         return x, laj
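A minimal usage sketch of the new Parameter-based constraint (not from the package; the values and the `sigma` name are illustrative). `Lower(lb)` maps each unconstrained array leaf through v -> exp(v) + lb, and the returned scalar is the summed log-absolute-Jacobian:

    import jax.numpy as jnp
    from bayinx.constraints import Lower
    from bayinx.core import Parameter

    # Unconstrained parameter values
    sigma = Parameter(jnp.array([-1.0, 0.5]))

    constrained, laj = Lower(0.0).constrain(sigma)
    # constrained.vals > 0 everywhere; laj == -1.0 + 0.5, the summed
    # log-Jacobian of v -> exp(v) + lb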
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/__init__.py
@@ -1,3 +1,4 @@
  from bayinx.core.flow import Flow as Flow
  from bayinx.core.model import Model as Model
+ from bayinx.core.parameter import Parameter as Parameter
  from bayinx.core.variational import Variational as Variational
bayinx-0.3.2/src/bayinx/core/constraint.py
@@ -0,0 +1,28 @@
+ from abc import abstractmethod
+ from typing import Tuple
+
+ import equinox as eqx
+ from jaxtyping import Scalar
+
+ from bayinx.core.parameter import Parameter
+
+
+ class Constraint(eqx.Module):
+     """
+     Abstract base class for defining parameter constraints.
+     """
+
+     @abstractmethod
+     def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+         """
+         Applies the constraining transformation to a parameter and computes the log-absolute-Jacobian of the transformation.
+
+         # Parameters
+         - `x`: The unconstrained `Parameter`.
+
+         # Returns
+         A tuple containing:
+         - The constrained `Parameter`.
+         - A scalar Array representing the log-absolute-Jacobian of the transformation.
+         """
+         pass
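A hypothetical subclass sketch showing how this interface is meant to be filled in. `Upper` is not part of the package; it simply mirrors the `Lower` implementation above with the transformation v -> ub - exp(v):

    from typing import Tuple

    import equinox as eqx
    import jax.numpy as jnp
    import jax.tree as jt
    from jaxtyping import Scalar, ScalarLike

    from bayinx.core.constraint import Constraint
    from bayinx.core.parameter import Parameter

    class Upper(Constraint):
        """Enforces an upper bound via x = ub - exp(v) (illustrative only)."""

        ub: Scalar

        def __init__(self, ub: ScalarLike):
            self.ub = jnp.array(ub)

        @eqx.filter_jit
        def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
            dyn, static = eqx.partition(x, x.filter_spec)
            # log|d(ub - exp(v))/dv| = v, summed across all dynamic leaves
            laj: Scalar = jt.reduce(lambda a, b: a + b, jt.map(jnp.sum, dyn))
            dyn = jt.map(lambda v: self.ub - jnp.exp(v), dyn)
            return eqx.combine(dyn, static), laj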
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/flow.py
@@ -31,11 +31,13 @@ class Flow(eqx.Module):
          Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.

          # Returns
-         A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
+         A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
          """
          pass

      # Default filter specification
+     @property
+     @eqx.filter_jit
      def filter_spec(self):
          """
          Generates a filter specification to subset relevant parameters for the flow.
@@ -53,7 +55,7 @@ class Flow(eqx.Module):
          return filter_spec

      @eqx.filter_jit
-     def constrain_pars(self: Self):
+     def constrain_params(self: Self):
          """
          Constrain `params` to the appropriate domain.

@@ -68,11 +70,11 @@ class Flow(eqx.Module):
          return t_params

      @eqx.filter_jit
-     def transform_pars(self: Self) -> Dict[str, Array]:
+     def transform_params(self: Self) -> Dict[str, Array]:
          """
          Apply a custom transformation to `params` if needed.

          # Returns
          A dictionary of transformed JAX Arrays representing the transformed parameters.
          """
-         return self.constrain_pars()
+         return self.constrain_params()
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/model.py
@@ -1,24 +1,25 @@
  from abc import abstractmethod
- from typing import Any, Dict, Tuple
+ from typing import Any, Dict, Generic, Tuple, TypeVar

  import equinox as eqx
  import jax.numpy as jnp
  import jax.tree as jt
- from jaxtyping import Array, PyTree, Scalar
+ from jaxtyping import PyTree, Scalar

  from bayinx.core.constraint import Constraint
+ from bayinx.core.parameter import Parameter

-
- class Model(eqx.Module):
+ P = TypeVar('P', bound=Dict[str, Parameter[PyTree]])
+ class Model(eqx.Module, Generic[P]):
      """
      An abstract base class used to define probabilistic models.

      # Attributes
-     - `params`: A dictionary of JAX Arrays representing parameters of the model.
+     - `params`: A dictionary of parameters.
      - `constraints`: A dictionary of constraints.
      """

-     params: Dict[str, PyTree]
+     params: P
      constraints: Dict[str, Constraint]

      @abstractmethod
@@ -26,6 +27,8 @@ class Model(eqx.Module):
          pass

      # Default filter specification
+     @property
+     @eqx.filter_jit
      def filter_spec(self):
          """
          Generates a filter specification to subset relevant parameters for the model.
@@ -33,25 +36,25 @@ class Model(eqx.Module):
          # Generate empty specification
          filter_spec = jt.map(lambda _: False, self)

-         # Specify JAX Array parameters
+         # Specify relevant parameters
          filter_spec = eqx.tree_at(
              lambda model: model.params,
              filter_spec,
-             replace=jt.map(eqx.is_array, self.params),
+             replace={key: param.filter_spec for key, param in self.params.items()}
          )

          return filter_spec

      # Add constrain method
      @eqx.filter_jit
-     def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+     def constrain_params(self) -> Tuple[P, Scalar]:
          """
          Constrain `params` to the appropriate domain.

          # Returns
-         A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
+         A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
          """
-         t_params: Dict[str, Array] = self.params
+         t_params: P = self.params
          target: Scalar = jnp.array(0.0)

          for par, map in self.constraints.items():
@@ -64,11 +67,12 @@ class Model(eqx.Module):
          return t_params, target

      # Add default transform method
-     def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+     @eqx.filter_jit
+     def transform_params(self) -> Tuple[P, Scalar]:
          """
          Apply a custom transformation to `params` if needed.

          # Returns
          A dictionary of transformed JAX Arrays representing the transformed parameters.
          """
-         return self.constrain_pars()
+         return self.constrain_params()
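A sketch of a model written against the new generic API, combining the pieces above. `Scale` is illustrative rather than from the package; it assumes the `Lower` constraint and the `normal.logprob` call visible elsewhere in this diff:

    from typing import Dict

    import equinox as eqx
    import jax.numpy as jnp
    from jaxtyping import Array

    from bayinx import Model, Parameter
    from bayinx.constraints import Lower
    from bayinx.dists import normal

    class Scale(Model[Dict[str, Parameter[Array]]]):
        def __init__(self):
            # sigma is stored unconstrained; Lower(0.0) keeps it positive
            self.params = {"sigma": Parameter(jnp.array(0.0))}
            self.constraints = {"sigma": Lower(0.0)}

        @eqx.filter_jit
        def eval(self, data=None):
            # constrain_params applies the constraints and seeds `target`
            # with the log-absolute-Jacobian adjustment
            params, target = self.constrain_params()
            target += normal.logprob(
                x=params["sigma"].vals, mu=jnp.array(1.0), sigma=jnp.array(1.0)
            ).sum()
            return target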
bayinx-0.3.2/src/bayinx/core/parameter.py
@@ -0,0 +1,41 @@
+ from typing import Generic, Self, TypeVar
+
+ import equinox as eqx
+ import jax.tree as jt
+ from jaxtyping import PyTree
+
+ T = TypeVar('T', bound=PyTree)
+ class Parameter(eqx.Module, Generic[T]):
+     """
+     A container for a parameter of a `Model`.
+
+     Subclasses can be constructed for custom filter specifications (`filter_spec`).
+
+     # Attributes
+     - `vals`: The parameter's value(s).
+     """
+     vals: T
+
+
+     def __init__(self, values: T):
+         # Insert parameter values
+         self.vals = values
+
+     # Default filter specification
+     @property
+     @eqx.filter_jit
+     def filter_spec(self) -> Self:
+         """
+         Generates a filter specification to filter out static parameters.
+         """
+         # Generate empty specification
+         filter_spec = jt.map(lambda _: False, self)
+
+         # Specify Array leaves
+         filter_spec = eqx.tree_at(
+             lambda params: params.vals,
+             filter_spec,
+             replace=jt.map(eqx.is_array_like, self.vals),
+         )
+
+         return filter_spec
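A short sketch (illustrative names, not from the package) of what the default `filter_spec` buys you: array-like leaves are marked dynamic and everything else static, so `eqx.partition` splits the container cleanly:

    import equinox as eqx
    import jax.numpy as jnp

    from bayinx.core import Parameter

    # A pytree-valued parameter; the string leaf is not array-like
    p = Parameter({"w": jnp.zeros(3), "label": "weights"})

    dyn, static = eqx.partition(p, p.filter_spec)
    # dyn carries the "w" array; "label" lands in the static half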
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/core/variational.py
@@ -103,7 +103,7 @@ class Variational(eqx.Module):
          - `key`: A PRNG key.
          """
          # Partition variational
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          # Construct scheduler
          schedule: Schedule = opx.cosine_decay_schedule(
@@ -143,7 +143,7 @@ class Variational(eqx.Module):

          # Compute updates
          updates, opt_state = optim.update(
-             updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+             updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
          )

          # Update variational distribution
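The pattern running through the rest of this release: `filter_spec` changes from a method to a `@property` (usually combined with `@eqx.filter_jit`), so every call site drops the parentheses. A before/after sketch:

    # bayinx 0.2.33
    dyn, static = eqx.partition(self, self.filter_spec())

    # bayinx 0.3.2
    dyn, static = eqx.partition(self, self.filter_spec)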
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/fullaffine.py
@@ -46,7 +46,7 @@ class FullAffine(Flow):

      @eqx.filter_jit
      def forward(self, draws: Array) -> Array:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          shift: Array = params["shift"]
@@ -60,7 +60,7 @@ class FullAffine(Flow):
      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0))
      def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          shift: Array = params["shift"]
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/planar.py
@@ -39,7 +39,7 @@ class Planar(Flow):
      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0))
      def forward(self, draws: Array) -> Array:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          w: Array = params["w"]
@@ -54,7 +54,7 @@ class Planar(Flow):
      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0))
      def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          w: Array = params["w"]
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/flows/radial.py
@@ -49,7 +49,7 @@ class Radial(Flow):
          # Returns
          The transformed samples.
          """
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          alpha = params["alpha"]
@@ -67,7 +67,7 @@ class Radial(Flow):
      @partial(jax.vmap, in_axes=(None, 0))
      @eqx.filter_jit
      def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          alpha = params["alpha"]
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/meanfield.py
@@ -29,7 +29,7 @@ class MeanField(Variational):
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         params, self._constraints = eqx.partition(model, model.filter_spec())
+         params, self._constraints = eqx.partition(model, model.filter_spec)

          # Flatten params component
          params, self._unflatten = ravel_pytree(params)
@@ -40,6 +40,22 @@ class MeanField(Variational):
              "log_std": jnp.zeros(params.size, dtype=params.dtype),
          }

+     @property
+     @eqx.filter_jit
+     def filter_spec(self):
+         # Generate empty specification
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         # Specify variational parameters
+         filter_spec = eqx.tree_at(
+             lambda mf: mf.var_params,
+             filter_spec,
+             replace=True,
+         )
+
+         return filter_spec
+
+
      @eqx.filter_jit
      def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
          # Sample variational draws
@@ -59,23 +75,9 @@ class MeanField(Variational):
              sigma=jnp.exp(self.var_params["log_std"]),
          ).sum(axis=1)

-     @eqx.filter_jit
-     def filter_spec(self):
-         # Generate empty specification
-         filter_spec = jtu.tree_map(lambda _: False, self)
-
-         # Specify variational parameters
-         filter_spec = eqx.tree_at(
-             lambda mf: mf.var_params,
-             filter_spec,
-             replace=True,
-         )
-
-         return filter_spec
-
      @eqx.filter_jit
      def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_jit
          def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
@@ -97,7 +99,7 @@ class MeanField(Variational):

      @eqx.filter_jit
      def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_grad
          @eqx.filter_jit
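A sketch of how these pieces are driven end to end, mirroring the calls visible in tests/test_variational.py further down; the constructor and `fit` arguments shown are the ones that appear in this diff:

    # NormalDist is the Parameter-based model defined in the test file
    model = NormalDist()

    # MeanField partitions via model.filter_spec and flattens the params
    vari = MeanField(model)
    vari = vari.fit(20000, var_draws=10)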
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/normalizing_flow.py
@@ -33,7 +33,7 @@ class NormalizingFlow(Variational):
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         params, self._constraints = eqx.partition(model, eqx.is_array)
+         params, self._constraints = eqx.partition(model, model.filter_spec)

          # Flatten params component
          _, self._unflatten = jfu.ravel_pytree(params)
@@ -41,6 +41,8 @@ class NormalizingFlow(Variational):
          self.base = base
          self.flows = flows

+     @property
+     @eqx.filter_jit
      def filter_spec(self):
          # Generate empty specification
          filter_spec = jtu.tree_map(lambda _: False, self)
@@ -49,7 +51,7 @@ class NormalizingFlow(Variational):
          filter_spec = eqx.tree_at(
              lambda vari: vari.flows,
              filter_spec,
-             replace=[flow.filter_spec() for flow in self.flows],
+             replace=[flow.filter_spec for flow in self.flows],
          )

          return filter_spec
@@ -112,7 +114,7 @@ class NormalizingFlow(Variational):

      @eqx.filter_jit
      def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_jit
          def elbo(dyn: Self, n: int, key: Key, data: Any = None):
@@ -129,7 +131,7 @@ class NormalizingFlow(Variational):

      @eqx.filter_jit
      def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_grad
          @eqx.filter_jit
{bayinx-0.2.33 → bayinx-0.3.2}/src/bayinx/mhx/vi/standard.py
@@ -19,7 +19,7 @@ class Standard(Variational):
      - `dim`: Dimension of the parameter space.
      """

- dim: int = eqx.field(static=True)
+ dim: int
  _unflatten: Callable[[Float[Array, "..."]], Model]
  _constraints: Model

@@ -31,7 +31,7 @@ class Standard(Variational):
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         params, self._constraints = eqx.partition(model, model.filter_spec())
+         params, self._constraints = eqx.partition(model, model.filter_spec)

          # Flatten params component
          params, self._unflatten = ravel_pytree(params)
@@ -54,7 +54,7 @@ class Standard(Variational):
              sigma=jnp.array(1.0),
          ).sum(axis=1, keepdims=True)

-     @eqx.filter_jit
+     @property
      def filter_spec(self):
          filter_spec = jtu.tree_map(lambda _: False, self)
{bayinx-0.2.33 → bayinx-0.3.2}/tests/test_variational.py
@@ -1,41 +1,36 @@
- from typing import Callable, Dict
+ from typing import Dict

  import equinox as eqx
  import jax.numpy as jnp
  import pytest
  from jaxtyping import Array

- from bayinx import Model
+ from bayinx import Model, Parameter
  from bayinx.dists import normal
  from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
  from bayinx.mhx.vi.flows import FullAffine, Planar, Radial


- # Tests ----
- @pytest.mark.parametrize("var_draws", [1, 10, 100])
- def test_meanfield(benchmark, var_draws):
-     # Construct model definition
-     class NormalDist(Model):
-         params: Dict[str, Array]
-         constraints: Dict[str, Callable[[Array], Array]]
+ class NormalDist(Model[Dict[str, Parameter[Array]]]):
+     def __init__(self):
+         self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
+         self.constraints = {}

-         def __init__(self):
-             self.params = {"mu": jnp.array([0.0, 0.0])}
-             self.constraints = {}
+     @eqx.filter_jit
+     def eval(self, data = None):
+         # Get constrained parameters
+         params, target = self.constrain_params()

-         @eqx.filter_jit
-         def eval(self, data: dict):
-             # Get constrained parameters
-             params, target = self.constrain_pars()
+         # Evaluate mu ~ N(10,1)
+         target += normal.logprob(
+             x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
+         ).sum()

-             # Evaluate mu ~ N(10,1)
-             target += normal.logprob(
-                 x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
-             ).sum()
-
-             # Evaluate mu ~ N(10,1)
-             return target
+         return target

+ # Tests ----
+ @pytest.mark.parametrize("var_draws", [1, 10, 100])
+ def test_meanfield(benchmark, var_draws):
      # Construct model
      model = NormalDist()

@@ -57,28 +52,6 @@ def test_meanfield(benchmark, var_draws):

  @pytest.mark.parametrize("var_draws", [1, 10, 100])
  def test_affine(benchmark, var_draws):
-     # Construct model definition
-     class NormalDist(Model):
-         params: Dict[str, Array]
-         constraints: Dict[str, Callable[[Array], Array]]
-
-         def __init__(self):
-             self.params = {"mu": jnp.array([0.0, 0.0])}
-             self.constraints = {}
-
-         @eqx.filter_jit
-         def eval(self, data: dict):
-             # Get constrained parameters
-             params, target = self.constrain_pars()
-
-             # Evaluate mu ~ N(10,1)
-             target += normal.logprob(
-                 x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
-             ).sum()
-
-             # Evaluate mu ~ N(10,1)
-             return target
-
      # Construct model
      model = NormalDist()

@@ -92,7 +65,7 @@ def test_affine(benchmark, var_draws):
      benchmark(benchmark_fit)
      vari = vari.fit(20000, var_draws=var_draws)

-     params = vari.flows[0].transform_pars()
+     params = vari.flows[0].transform_params()
      assert (abs(10.0 - vari.flows[0].params["shift"]) < 0.1).all() and (
          abs(jnp.eye(2) - params["scale"]) < 0.1
      ).all()
@@ -102,21 +75,18 @@
  def test_flows(benchmark, var_draws):
      # Construct model definition
      class NormalDist(Model):
-         params: Dict[str, Array]
-         constraints: Dict[str, Callable[[Array], Array]]
-
          def __init__(self):
-             self.params = {"mu": jnp.array([0.0, 0.0])}
+             self.params = {"mu": Parameter(jnp.array([0.0, 0.0]))}
              self.constraints = {}

          @eqx.filter_jit
          def eval(self, data: dict):
              # Get constrained parameters
-             params, target = self.constrain_pars()
+             params, target = self.constrain_params()

              # Evaluate mu ~ N(10,1)
              target += normal.logprob(
-                 x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0)
+                 x=params["mu"].vals, mu=jnp.array(10.0), sigma=jnp.array(1.0)
              ).sum()

              return target
bayinx-0.2.33/src/bayinx/__init__.py
@@ -1 +0,0 @@
- from bayinx.core.model import Model as Model
bayinx-0.2.33/src/bayinx/constraints/lower.py
@@ -1,37 +0,0 @@
- from typing import Tuple
-
- import jax.numpy as jnp
- from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
-
- from bayinx.core.constraint import Constraint
-
-
- class Lower(Constraint):
-     """
-     Enforces a lower bound on the parameter.
-     """
-
-     lb: ScalarLike
-
-     def __init__(self, lb: ScalarLike):
-         self.lb = lb
-
-     def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
-         """
-         Applies the lower bound constraint and adjusts the posterior density.
-
-         # Parameters
-         - `x`: The unconstrained JAX Array-like input.
-
-         # Parameters
-         A tuple containing:
-         - The constrained JAX Array (x > self.lb).
-         - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
-         """
-         # Compute transformation adjustment
-         laj: Scalar = jnp.sum(x)
-
-         # Compute transformation
-         x = jnp.exp(x) + self.lb
-
-         return x, laj
bayinx-0.2.33/src/bayinx/core/constraint.py
@@ -1,26 +0,0 @@
- from abc import abstractmethod
- from typing import Tuple
-
- import equinox as eqx
- from jaxtyping import Array, ArrayLike, Scalar
-
-
- class Constraint(eqx.Module):
-     """
-     Abstract base class for defining parameter constraints.
-     """
-
-     @abstractmethod
-     def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
-         """
-         Applies the constraining transformation to an unconstrained input and computes the log-absolute-Jacobian of the transformation.
-
-         # Parameters
-         - `x`: The unconstrained JAX Array-like input.
-
-         # Returns
-         A tuple containing:
-         - The constrained JAX Array.
-         - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
-         """
-         pass