bayinx 0.2.6.tar.gz → 0.2.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {bayinx-0.2.6 → bayinx-0.2.9}/PKG-INFO +2 -4
  2. {bayinx-0.2.6 → bayinx-0.2.9}/README.md +1 -1
  3. {bayinx-0.2.6 → bayinx-0.2.9}/pyproject.toml +4 -3
  4. bayinx-0.2.9/src/bayinx/core/flow.py +80 -0
  5. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/core/model.py +19 -0
  6. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/core/variational.py +17 -10
  7. bayinx-0.2.9/src/bayinx/mhx/__init__.py +1 -0
  8. bayinx-0.2.9/src/bayinx/mhx/vi/__init__.py +3 -0
  9. bayinx-0.2.9/src/bayinx/mhx/vi/flows/__init__.py +3 -0
  10. {bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/flows/affine.py +1 -3
  11. {bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/flows/planar.py +1 -3
  12. {bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/flows/radial.py +3 -4
  13. bayinx-0.2.9/src/bayinx/mhx/vi/flows/sylvester.py +19 -0
  14. {bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/meanfield.py +4 -4
  15. {bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/normalizing_flow.py +12 -14
  16. {bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/standard.py +6 -6
  17. {bayinx-0.2.6 → bayinx-0.2.9}/tests/test_variational.py +14 -11
  18. {bayinx-0.2.6 → bayinx-0.2.9}/uv.lock +12 -4
  19. bayinx-0.2.6/src/bayinx/core/flow.py +0 -68
  20. bayinx-0.2.6/src/bayinx/machinery/variational/__init__.py +0 -5
  21. bayinx-0.2.6/src/bayinx/machinery/variational/flows/__init__.py +0 -3
  22. bayinx-0.2.6/src/bayinx/machinery/variational/flows/sylvester.py +0 -76
  23. bayinx-0.2.6/tests/__init__.py +0 -0
  24. {bayinx-0.2.6 → bayinx-0.2.9}/.github/workflows/release_and_publish.yml +0 -0
  25. {bayinx-0.2.6 → bayinx-0.2.9}/.gitignore +0 -0
  26. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/__init__.py +0 -0
  27. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/core/__init__.py +0 -0
  28. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/core/utils.py +0 -0
  29. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/dists/__init__.py +0 -0
  30. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/dists/bernoulli.py +0 -0
  31. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/dists/binomial.py +0 -0
  32. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/dists/gamma.py +0 -0
  33. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/dists/gamma2.py +0 -0
  34. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/dists/normal.py +0 -0
  35. {bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/py.typed +0 -0
  36. {bayinx-0.2.6/src/bayinx/machinery → bayinx-0.2.9/tests}/__init__.py +0 -0
{bayinx-0.2.6 → bayinx-0.2.9}/PKG-INFO

@@ -1,14 +1,12 @@
  Metadata-Version: 2.4
  Name: bayinx
- Version: 0.2.6
+ Version: 0.2.9
  Summary: Bayesian Inference with JAX
  Requires-Python: >=3.12
  Requires-Dist: equinox>=0.11.12
  Requires-Dist: jax>=0.4.38
  Requires-Dist: jaxtyping>=0.2.36
  Requires-Dist: optax>=0.2.4
- Requires-Dist: pytest-benchmark>=5.1.0
- Requires-Dist: pytest>=8.3.5
  Description-Content-Type: text/markdown

  # <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
@@ -23,7 +21,7 @@ In the short-term, I'm going to focus on:
  In the long-term, I'm going to focus on:
  1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead(to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed(same as `Stan` does).
  2) How to make working with the posterior as easy as possible.
- - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
+ - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{D}, \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{D}, \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.

  Although this is somewhat separate from the goals of the project, if this does pan out how I'm invisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar(think `brms`).

{bayinx-0.2.6 → bayinx-0.2.9}/README.md

@@ -10,7 +10,7 @@ In the short-term, I'm going to focus on:
  In the long-term, I'm going to focus on:
  1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead(to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed(same as `Stan` does).
  2) How to make working with the posterior as easy as possible.
- - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
+ - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{D}, \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{D}, \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.

  Although this is somewhat separate from the goals of the project, if this does pan out how I'm invisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar(think `brms`).

{bayinx-0.2.6 → bayinx-0.2.9}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "bayinx"
- version = "0.2.6"
+ version = "0.2.9"
  description = "Bayesian Inference with JAX"
  readme = "README.md"
  requires-python = ">=3.12"
@@ -9,8 +9,6 @@ dependencies = [
      "jax>=0.4.38",
      "jaxtyping>=0.2.36",
      "optax>=0.2.4",
-     "pytest>=8.3.5",
-     "pytest-benchmark>=5.1.0",
  ]

  [build-system]
@@ -19,3 +17,6 @@ build-backend = "hatchling.build"

  [tool.pytest.ini_options]
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
+
+ [dependency-groups]
+ dev = ["pytest>=8.3.5", "pytest-benchmark>=5.1.0"]
bayinx-0.2.9/src/bayinx/core/flow.py (new file)

@@ -0,0 +1,80 @@
+ from abc import abstractmethod
+ from typing import Callable, Dict, Self, Tuple
+
+ import equinox as eqx
+ import jax.tree_util as jtu
+ from jaxtyping import Array, Float
+
+ from bayinx.core.utils import __MyMeta
+
+
+ class Flow(eqx.Module, metaclass=__MyMeta):
+     """
+     A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
+
+     # Attributes
+     - `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
+     - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
+
+     @abstractmethod
+     def forward(self, draws: Array) -> Array:
+         """
+         Computes the forward transformation of `draws`.
+         """
+         pass
+
+     @abstractmethod
+     def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
+         """
+         Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
+
+         # Returns
+         A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
+         """
+         pass
+
+     # Default filter specification
+     def filter_spec(self):
+         """
+         Generates a filter specification to subset relevant parameters for the flow.
+         """
+         # Generate empty specification
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         # Specify JAX Array parameters
+         filter_spec = eqx.tree_at(
+             lambda flow: flow.params,
+             filter_spec,
+             replace=jtu.tree_map(eqx.is_array, self.params),
+         )
+
+         return filter_spec
+
+     @eqx.filter_jit
+     def constrain_pars(self: Self):
+         """
+         Constrain `params` to the appropriate domain.
+
+         # Returns
+         A dictionary of transformed JAX Arrays representing the constrained parameters.
+         """
+         t_params = self.params
+
+         for par, map in self.constraints.items():
+             t_params[par] = map(t_params[par])
+
+         return t_params
+
+     @eqx.filter_jit
+     def transform_pars(self: Self) -> Dict[str, Array]:
+         """
+         Apply a custom transformation to `params` if needed.
+
+         # Returns
+         A dictionary of transformed JAX Arrays representing the transformed parameters.
+         """
+         return self.constrain_pars()
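For orientation, here is a minimal sketch of what a concrete subclass of this new `Flow` base class could look like. The `Shift` flow below is hypothetical (it is not part of bayinx); only the `params`/`constraints` fields and the `forward`/`adjust_density` contract are taken from the base class above.

```python
from typing import Callable, Dict, Tuple

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float

from bayinx.core import Flow


class Shift(Flow):
    """Hypothetical toy flow: shifts draws by a learnable offset."""

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]

    def __init__(self, dim: int):
        self.params = {"shift": jnp.zeros(dim)}
        self.constraints = {}  # no parameters need constraining

    @eqx.filter_jit
    def forward(self, draws: Array) -> Array:
        params = self.transform_pars()
        return draws + params["shift"]

    @eqx.filter_jit
    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
        # A pure translation has a unit Jacobian, so the log-abs-det term is zero.
        return jnp.zeros(draws.shape[0]), self.forward(draws)
```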
{bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/core/model.py

@@ -2,6 +2,7 @@ from abc import abstractmethod
  from typing import Any, Callable, Dict

  import equinox as eqx
+ import jax.tree_util as jtu
  from jaxtyping import Array, Scalar

  from bayinx.core.utils import __MyMeta
@@ -23,6 +24,23 @@ class Model(eqx.Module, metaclass=__MyMeta):
      def eval(self, data: Any) -> Scalar:
          pass

+     # Default filter specification
+     def filter_spec(self):
+         """
+         Generates a filter specification to subset relevant parameters for the model.
+         """
+         # Generate empty specification
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         # Specify JAX Array parameters
+         filter_spec = eqx.tree_at(
+             lambda model: model.params,
+             filter_spec,
+             replace=jtu.tree_map(eqx.is_array, self.params),
+         )
+
+         return filter_spec
+
      def __init_subclass__(cls):
          # Add constrain method
          def constrain_pars(self: Model) -> Dict[str, Array]:
@@ -43,6 +61,7 @@ class Model(eqx.Module, metaclass=__MyMeta):

          # Add transform_pars method if not present
          if not callable(getattr(cls, "transform_pars", None)):
+
              def transform_pars(self: Model) -> Dict[str, Array]:
                  """
                  Apply a custom transformation to `params` if needed.
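The new `Model.filter_spec` exists so callers can split a model into its trainable array leaves and everything else. A one-line sketch of the intended call pattern, assuming `model` is an instance of a `Model` subclass (this is exactly how `meanfield.py` and `standard.py` consume it further down):

```python
import equinox as eqx

# Split a Model into trainable array leaves and static structure;
# only the `params` leaves are marked True by the default filter_spec.
params, static = eqx.partition(model, model.filter_spec())
```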
{bayinx-0.2.6 → bayinx-0.2.9}/src/bayinx/core/variational.py

@@ -104,6 +104,9 @@ class Variational(eqx.Module):
          - `var_draws`: Number of variational draws to draw each iteration.
          - `key`: A PRNG key.
          """
+         # Partition variational
+         dyn, static = eqx.partition(self, self.filter_spec())
+
          # Construct scheduler
          schedule: Schedule = opx.cosine_decay_schedule(
              init_value=learning_rate, decay_steps=max_iters
@@ -113,20 +116,20 @@ class Variational(eqx.Module):
          optim: GradientTransformation = opx.chain(
              opx.scale(-1.0), opx.nadam(schedule)
          )
-         opt_state: OptState = optim.init(eqx.filter(self, self.filter_spec()))
+         opt_state: OptState = optim.init(dyn)

          # Optimization loop helper functions
          @eqx.filter_jit
          def condition(state: Tuple[Self, OptState, Scalar, Key]):
              # Unpack iteration state
-             self, opt_state, i, key = state
+             dyn, opt_state, i, key = state

              return i < max_iters

          @eqx.filter_jit
          def body(state: Tuple[Self, OptState, Scalar, Key]):
              # Unpack iteration state
-             self, opt_state, i, key = state
+             dyn, opt_state, i, key = state

              # Update iteration
              i = i + 1
@@ -134,26 +137,30 @@ class Variational(eqx.Module):
              # Update PRNG key
              key, _ = jr.split(key)

+             # Combine variational
+             vari = eqx.combine(dyn, static)
+
              # Compute gradient of the ELBO
-             updates: PyTree = self.elbo_grad(var_draws, key, data)
+             updates: PyTree = vari.elbo_grad(var_draws, key, data)

              # Compute updates
              updates, opt_state = optim.update(
-                 updates, opt_state, eqx.filter(self, self.filter_spec())
+                 updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
              )

              # Update variational distribution
-             self: Self = eqx.apply_updates(self, updates)
+             dyn = eqx.apply_updates(dyn, updates)

-             return self, opt_state, i, key
+             return dyn, opt_state, i, key

          # Run optimization loop
-         self = lax.while_loop(
+         dyn = lax.while_loop(
              cond_fun=condition,
              body_fun=body,
-             init_val=(self, opt_state, jnp.array(0, jnp.uint32), key),
+             init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
          )[0]

-         return self
+         # Return optimized variational
+         return eqx.combine(dyn, static)

      cls.fit = eqx.filter_jit(fit)
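The rewritten `fit` now threads only the trainable partition (`dyn`) through `lax.while_loop` and re-attaches the static leaves afterwards, since non-array fields (such as stored closures) cannot be carried through a JAX loop. Below is a condensed, self-contained sketch of that pattern, with a hypothetical `step` update function standing in for the ELBO gradient step:

```python
import equinox as eqx
import jax.lax as lax
import jax.numpy as jnp


def optimize(module: eqx.Module, step, n_steps: int) -> eqx.Module:
    # Split into trainable arrays (dyn) and everything else (static).
    dyn, static = eqx.partition(module, eqx.is_array)

    def body(state):
        dyn, i = state
        full = eqx.combine(dyn, static)   # rebuild the full module inside the loop
        full = step(full)                 # e.g. one gradient update
        dyn, _ = eqx.partition(full, eqx.is_array)
        return dyn, i + 1

    dyn, _ = lax.while_loop(
        cond_fun=lambda state: state[1] < n_steps,
        body_fun=body,
        init_val=(dyn, jnp.array(0, jnp.uint32)),
    )

    # Re-attach the static fields to the optimized arrays.
    return eqx.combine(dyn, static)
```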
bayinx-0.2.9/src/bayinx/mhx/__init__.py (new file)

@@ -0,0 +1 @@
+
bayinx-0.2.9/src/bayinx/mhx/vi/__init__.py (new file)

@@ -0,0 +1,3 @@
+ from bayinx.mhx.vi.meanfield import MeanField as MeanField
+ from bayinx.mhx.vi.normalizing_flow import NormalizingFlow as NormalizingFlow
+ from bayinx.mhx.vi.standard import Standard as Standard
bayinx-0.2.9/src/bayinx/mhx/vi/flows/__init__.py (new file)

@@ -0,0 +1,3 @@
+ from bayinx.mhx.vi.flows.affine import Affine as Affine
+ from bayinx.mhx.vi.flows.planar import Planar as Planar
+ from bayinx.mhx.vi.flows.radial import Radial as Radial
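Net effect of the package reshuffle: everything that lived under `bayinx.machinery.variational` now imports from `bayinx.mhx.vi`, as the updated tests below confirm. For example:

```python
# 0.2.6 (removed):
# from bayinx.machinery.variational import MeanField, NormalizingFlow, Standard
# from bayinx.machinery.variational.flows.affine import Affine

# 0.2.9:
from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
from bayinx.mhx.vi.flows import Affine, Planar, Radial
```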
{bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/flows/affine.py

@@ -19,9 +19,7 @@ class Affine(Flow):
      """

      params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
-         eqx.field(static=True)
-     )
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]

      def __init__(self, dim: int):
          """
{bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/flows/planar.py

@@ -20,9 +20,7 @@ class Planar(Flow):
      """

      params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
-         eqx.field(static=True)
-     )
+     constraints: Dict[str, Callable[[Array], Array]]

      def __init__(self, dim: int, key=jr.PRNGKey(0)):
          """
{bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/flows/radial.py

@@ -21,9 +21,7 @@ class Radial(Flow):
      """

      params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
-         eqx.field(static=True)
-     )
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]

      def __init__(self, dim: int, key=jr.PRNGKey(0)):
          """
@@ -88,7 +86,8 @@
          # Compute density adjustment
          ladj = jnp.log(
              jnp.abs(
-                 (1.0 + alpha * beta / (alpha + r) ** 2.0) * (1.0 + x) ** (center.size - 1.0)
+                 (1.0 + alpha * beta / (alpha + r) ** 2.0)
+                 * (1.0 + x) ** (center.size - 1.0)
              )
          )

bayinx-0.2.9/src/bayinx/mhx/vi/flows/sylvester.py (new file)

@@ -0,0 +1,19 @@
+ from typing import Callable, Dict
+
+ from jaxtyping import Array, Float
+
+ from bayinx.core import Flow
+
+
+ # TODO
+ class Sylvester(Flow):
+     """
+     A sylvester flow.
+
+     # Attributes
+     - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+     - `constraints`: A dictionary of constraining transformations.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
{bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/meanfield.py

@@ -31,15 +31,15 @@ class MeanField(Variational):
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         params, self._constraints = eqx.partition(model, eqx.is_array)
+         params, self._constraints = eqx.partition(model, model.filter_spec())

          # Flatten params component
-         flat_params, self._unflatten = ravel_pytree(params)
+         params, self._unflatten = ravel_pytree(params)

          # Initialize variational parameters
          self.var_params = {
-             "mean": flat_params,
-             "log_std": jnp.zeros(flat_params.size, dtype=flat_params.dtype),
+             "mean": params,
+             "log_std": jnp.zeros(params.size, dtype=params.dtype),
          }

      @eqx.filter_jit
{bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/normalizing_flow.py

@@ -23,8 +23,8 @@ class NormalizingFlow(Variational):

      flows: list[Flow]
      base: Variational
-     _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
-     _constraints: Model = eqx.field(static=True)
+     _unflatten: Callable[[Float[Array, "..."]], Model]
+     _constraints: Model

      def __init__(self, base: Variational, flows: list[Flow], model: Model):
          """
@@ -39,7 +39,7 @@ class NormalizingFlow(Variational):
          params, self._constraints = eqx.partition(model, eqx.is_array)

          # Flatten params component
-         flat_params, self._unflatten = jfu.ravel_pytree(params)
+         _, self._unflatten = jfu.ravel_pytree(params)

          self.base = base
          self.flows = flows
@@ -73,7 +73,7 @@ class NormalizingFlow(Variational):
          return variational_evals

      @eqx.filter_jit
-     def _eval(self, draws: Array, data=None) -> Tuple[Scalar, Array]:
+     def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
          """
          Evaluate the posterior and variational densities at the transformed
          `draws` to avoid extra compute when requiring variational draws for
@@ -84,7 +84,7 @@ class NormalizingFlow(Variational):
          - `data`: Any data required to evaluate the posterior density.

          # Returns
-         The posterior and variational densities.
+         The posterior and variational densities as JAX Arrays.
          """
          # Evaluate base density
          variational_evals: Array = self.base.eval(draws)
@@ -102,30 +102,30 @@ class NormalizingFlow(Variational):
          return posterior_evals, variational_evals

      def filter_spec(self):
-         # Only optimize the parameters of the flows
+         # Generate empty specification
          filter_spec = jtu.tree_map(lambda _: False, self)
+
+         # Specify variational parameters based on each flow's filter spec.
          filter_spec = eqx.tree_at(
-             lambda nf: nf.flows,
+             lambda vari: vari.flows,
              filter_spec,
-             replace=True,
+             replace=[flow.filter_spec() for flow in self.flows],
          )

          return filter_spec

      @eqx.filter_jit
      def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
-         # Partition
          dyn, static = eqx.partition(self, self.filter_spec())

          @eqx.filter_jit
          def elbo(dyn: Self, n: int, key: Key, data: Any = None):
-             # Combine
              self = eqx.combine(dyn, static)

              # Sample draws from variational distribution
              draws: Array = self.base.sample(n, key)

-             posterior_evals, variational_evals = self._eval(draws, data)
+             posterior_evals, variational_evals = self.__eval(draws, data)
              # Evaluate ELBO
              return jnp.mean(posterior_evals - variational_evals)

@@ -133,19 +133,17 @@ class NormalizingFlow(Variational):

      @eqx.filter_jit
      def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-         # Partition
          dyn, static = eqx.partition(self, self.filter_spec())

          @eqx.filter_grad
          @eqx.filter_jit
          def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
-             # Combine
              self = eqx.combine(dyn, static)

              # Sample draws from variational distribution
              draws: Array = self.base.sample(n, key)

-             posterior_evals, variational_evals = self._eval(draws, data)
+             posterior_evals, variational_evals = self.__eval(draws, data)
              # Evaluate ELBO
              return jnp.mean(posterior_evals - variational_evals)

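One subtlety of the `_eval` → `__eval` rename: double-underscore attributes are name-mangled per class, so the helper becomes effectively private to `NormalizingFlow`, while internal calls like `self.__eval(draws, data)` keep working. A minimal illustration of the general Python behavior (the `Demo` class is a stand-in, not bayinx code):

```python
class Demo:
    def __eval(self):          # stored on the class as _Demo__eval
        return "ok"

    def elbo(self):
        return self.__eval()   # inside the class body this mangles to _Demo__eval


demo = Demo()
print(demo.elbo())                   # "ok"
print(hasattr(demo, "_Demo__eval"))  # True: only the mangled name is reachable
# demo.__eval()                      # would raise AttributeError outside the class
```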
{bayinx-0.2.6/src/bayinx/machinery/variational → bayinx-0.2.9/src/bayinx/mhx/vi}/standard.py

@@ -13,15 +13,15 @@ from bayinx.dists import normal

  class Standard(Variational):
      """
-     A standard normal distribution approximation to a posterior distribution.
+     A standard normal approximation to a posterior distribution.

      # Attributes
      - `dim`: Dimension of the parameter space.
      """

      dim: int = eqx.field(static=True)
-     _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
-     _constraints: Model = eqx.field(static=True)
+     _unflatten: Callable[[Float[Array, "..."]], Model]
+     _constraints: Model

      def __init__(self, model: Model):
          """
@@ -31,13 +31,13 @@ class Standard(Variational):
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         _, self._constraints = eqx.partition(model, eqx.is_array)
+         params, self._constraints = eqx.partition(model, model.filter_spec())

          # Flatten params component
-         _, self._unflatten = ravel_pytree(_)
+         params, self._unflatten = ravel_pytree(params)

          # Store dimension of parameter space
-         self.dim = jnp.size(_)
+         self.dim = jnp.size(params)

      @eqx.filter_jit
      def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
{bayinx-0.2.6 → bayinx-0.2.9}/tests/test_variational.py

@@ -7,16 +7,14 @@ from jaxtyping import Array

  from bayinx import Model
  from bayinx.dists import normal
- from bayinx.machinery.variational import MeanField, NormalizingFlow, Standard
- from bayinx.machinery.variational.flows.affine import Affine
- from bayinx.machinery.variational.flows.planar import Planar
- from bayinx.machinery.variational.flows.radial import Radial
+ from bayinx.mhx.vi import MeanField, NormalizingFlow, Standard
+ from bayinx.mhx.vi.flows import Affine, Planar, Radial


  # Tests ----
  @pytest.mark.parametrize("var_draws", [1, 10, 100])
  def test_meanfield(benchmark, var_draws):
-     # Construct model
+     # Construct model definition
      class NormalDist(Model):
          params: Dict[str, Array]
          constraints: Dict[str, Callable[[Array], Array]]
@@ -35,6 +33,7 @@ def test_meanfield(benchmark, var_draws):
              normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
          )

+     # Construct model
      model = NormalDist()

      # Construct meanfield variational
@@ -42,7 +41,7 @@ def test_meanfield(benchmark, var_draws):

      # Optimize variational distribution
      def benchmark_fit():
-         vari.fit(10000, var_draws = var_draws)
+         vari.fit(10000, var_draws=var_draws)

      benchmark(benchmark_fit)
      vari = vari.fit(10000)
@@ -55,7 +54,7 @@ def test_meanfield(benchmark, var_draws):

  @pytest.mark.parametrize("var_draws", [1, 10, 100])
  def test_affine(benchmark, var_draws):
-     # Construct model
+     # Construct model definition
      class NormalDist(Model):
          params: Dict[str, Array]
          constraints: Dict[str, Callable[[Array], Array]]
@@ -74,6 +73,7 @@ def test_affine(benchmark, var_draws):
              normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
          )

+     # Construct model
      model = NormalDist()

      # Construct normalizing flow variational
@@ -81,7 +81,7 @@ def test_affine(benchmark, var_draws):

      # Optimize variational distribution
      def benchmark_fit():
-         vari.fit(10000, var_draws = var_draws)
+         vari.fit(10000, var_draws=var_draws)

      benchmark(benchmark_fit)
      vari = vari.fit(10000)
@@ -94,7 +94,7 @@ def test_affine(benchmark, var_draws):

  @pytest.mark.parametrize("var_draws", [1, 10, 100])
  def test_flows(benchmark, var_draws):
-     # Construct model
+     # Construct model definition
      class NormalDist(Model):
          params: Dict[str, Array]
          constraints: Dict[str, Callable[[Array], Array]]
@@ -113,14 +113,17 @@ def test_flows(benchmark, var_draws):
              normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
          )

+     # Construct model
      model = NormalDist()

      # Construct normalizing flow variational
-     vari = NormalizingFlow(Standard(model), [Planar(2), Radial(2), Planar(2), Radial(2), Planar(2)], model)
+     vari = NormalizingFlow(
+         Standard(model), [Planar(2), Radial(2), Planar(2), Radial(2), Planar(2)], model
+     )

      # Optimize variational distribution
      def benchmark_fit():
-         vari.fit(10000, var_draws = var_draws)
+         vari.fit(10000, var_draws=var_draws)

      benchmark(benchmark_fit)
      vari = vari.fit(10000)
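Taken together, the updated tests sketch the 0.2.9 workflow end to end. The condensed example below is illustrative: the `NormalDist.__init__` body is an assumption (the diff shows only `eval` and the class fields), while the imports, `NormalizingFlow` construction, and `fit` call mirror the test code:

```python
from typing import Callable, Dict

import jax.numpy as jnp
from jaxtyping import Array

from bayinx import Model
from bayinx.dists import normal
from bayinx.mhx.vi import NormalizingFlow, Standard
from bayinx.mhx.vi.flows import Planar, Radial


class NormalDist(Model):
    params: Dict[str, Array]
    constraints: Dict[str, Callable[[Array], Array]]

    def __init__(self):
        # Assumed initializer: a 2D location parameter, no constraints.
        self.params = {"mu": jnp.zeros(2)}
        self.constraints = {}

    def eval(self, data=None):
        params = self.transform_pars()
        return jnp.sum(
            normal.logprob(x=params["mu"], mu=jnp.array(10.0), sigma=jnp.array(1.0))
        )


model = NormalDist()
vari = NormalizingFlow(Standard(model), [Planar(2), Radial(2)], model)
vari = vari.fit(10000, var_draws=10)  # returns the optimized variational
```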
{bayinx-0.2.6 → bayinx-0.2.9}/uv.lock

@@ -17,13 +17,17 @@ wheels = [

  [[package]]
  name = "bayinx"
- version = "0.2.1"
+ version = "0.2.2"
  source = { editable = "." }
  dependencies = [
      { name = "equinox" },
      { name = "jax" },
      { name = "jaxtyping" },
      { name = "optax" },
+ ]
+
+ [package.dev-dependencies]
+ dev = [
      { name = "pytest" },
      { name = "pytest-benchmark" },
  ]
@@ -34,6 +38,10 @@ requires-dist = [
      { name = "jax", specifier = ">=0.4.38" },
      { name = "jaxtyping", specifier = ">=0.2.36" },
      { name = "optax", specifier = ">=0.2.4" },
+ ]
+
+ [package.metadata.requires-dev]
+ dev = [
      { name = "pytest", specifier = ">=8.3.5" },
      { name = "pytest-benchmark", specifier = ">=5.1.0" },
  ]
@@ -96,11 +104,11 @@ epy = [

  [[package]]
  name = "iniconfig"
- version = "2.0.0"
+ version = "2.1.0"
  source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
+ sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 }
  wheels = [
- { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
+ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 },
  ]

  [[package]]
bayinx-0.2.6/src/bayinx/core/flow.py (deleted)

@@ -1,68 +0,0 @@
- from abc import abstractmethod
- from typing import Callable, Dict, Self, Tuple
-
- import equinox as eqx
- from jaxtyping import Array, Float
-
- from bayinx.core.utils import __MyMeta
-
-
- class Flow(eqx.Module, metaclass=__MyMeta):
-     """
-     A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
-
-     # Attributes
-     - `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
-     - `constraints`: A dictionary of functions that constrain their corresponding parameter.
-     """
-
-     params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
-
-     @abstractmethod
-     def forward(self, draws: Array) -> Array:
-         """
-         Computes the forward transformation of `draws`.
-         """
-         pass
-
-     @abstractmethod
-     def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
-         """
-         Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
-
-         # Returns
-         A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
-         """
-         pass
-
-     def __init_subclass__(cls):
-         # Add contrain_pars method
-         def constrain_pars(self: Self):
-             """
-             Constrain `params` to the appropriate domain.
-
-             # Returns
-             A dictionary of transformed JAX Arrays representing the constrained parameters.
-             """
-             t_params = self.params
-
-             for par, map in self.constraints.items():
-                 t_params[par] = map(t_params[par])
-
-             return t_params
-
-         cls.constrain_pars = eqx.filter_jit(constrain_pars)
-
-         # Add transform_pars method if not present
-         if not callable(getattr(cls, "transform_pars", None)):
-             def transform_pars(self: Self) -> Dict[str, Array]:
-                 """
-                 Apply a custom transformation to `params` if needed.
-
-                 # Returns
-                 A dictionary of transformed JAX Arrays representing the transformed parameters.
-                 """
-                 return self.constrain_pars()
-
-             cls.transform_pars = eqx.filter_jit(transform_pars)
bayinx-0.2.6/src/bayinx/machinery/variational/__init__.py (deleted)

@@ -1,5 +0,0 @@
- from bayinx.machinery.variational.meanfield import MeanField as MeanField
- from bayinx.machinery.variational.normalizing_flow import (
-     NormalizingFlow as NormalizingFlow,
- )
- from bayinx.machinery.variational.standard import Standard as Standard
bayinx-0.2.6/src/bayinx/machinery/variational/flows/__init__.py (deleted)

@@ -1,3 +0,0 @@
- from bayinx.machinery.variational.flows.affine import Affine as Affine
- from bayinx.machinery.variational.flows.planar import Planar as Planar
- from bayinx.machinery.variational.flows.radial import Radial as Radial
bayinx-0.2.6/src/bayinx/machinery/variational/flows/sylvester.py (deleted)

@@ -1,76 +0,0 @@
- from functools import partial
- from typing import Callable, Dict, Tuple
-
- import equinox as eqx
- import jax
- import jax.numpy as jnp
- import jax.random as jr
- from jaxtyping import Array, Float, Scalar
-
- from bayinx.core import Flow
-
-
- class Sylvester(Flow):
-     """
-     A sylvester flow.
-
-     # Attributes
-     - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
-     - `constraints`: A dictionary of constraining transformations.
-     """
-
-     params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
-         eqx.field(static=True)
-     )
-
-     def __init__(self, dim: int, key=jr.PRNGKey(0)):
-         """
-         Initializes a planar flow.
-
-         # Parameters
-         - `dim`: The dimension of the parameter space.
-         """
-         self.params = {
-             "u": jr.normal(key, (dim,)),
-             "w": jr.normal(key, (dim,)),
-             "b": jr.normal(key, (1,)),
-         }
-         self.constraints = {}
-
-     @eqx.filter_jit
-     @partial(jax.vmap, in_axes=(None, 0))
-     def forward(self, draws: Array) -> Array:
-         params = self.constrain_pars()
-
-         # Extract parameters
-         w: Array = params["w"]
-         u: Array = params["u"]
-         b: Array = params["b"]
-
-         # Compute forward transformation
-         draws = draws + u * jnp.tanh(draws.dot(w) + b)
-
-         return draws
-
-     @eqx.filter_jit
-     @partial(jax.vmap, in_axes=(None, 0))
-     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
-         params = self.constrain_pars()
-
-         # Extract parameters
-         w: Array = params["w"]
-         u: Array = params["u"]
-         b: Array = params["b"]
-
-         # Compute shared intermediates
-         x: Array = draws.dot(w) + b
-
-         # Compute forward transformation
-         draws = draws + u * jnp.tanh(x)
-
-         # Compute ladj
-         h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
-         ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
-
-         return ladj, draws