bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/vi/normalizing_flow.py ADDED
@@ -0,0 +1,246 @@
+ from typing import Callable, Optional, Self, Tuple
+
+ import equinox as eqx
+ import jax.flatten_util as jfu
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree_util as jtu
+ from jax.lax import scan
+ from jaxtyping import Array, PRNGKeyArray, Scalar
+
+ from bayinx.core.flow import FlowLayer
+ from bayinx.core.variational import M, Variational
+
+
+ class NormalizingFlow(Variational[M]):
+     """
+     An ordered collection of diffeomorphisms that map a base distribution to a
+     normalized approximation of a posterior distribution.
+
+     # Attributes
+     - `dim`: The dimension of the support.
+     - `base`: A base variational distribution.
+     - `flows`: An ordered collection of continuously parameterized diffeomorphisms.
+     """
+     flows: list[FlowLayer]
+     base: Variational[M]
+
+     def __init__(
+         self,
+         base: Variational[M],
+         flows: list[FlowLayer],
+         model: Optional[M] = None,
+         _static: Optional[M] = None,
+         _unflatten: Optional[Callable[[Array], M]] = None
+     ):
+         """
+         Constructs an unoptimized normalizing flow posterior approximation.
+
+         # Parameters
+         - `base`: The base variational distribution.
+         - `flows`: A list of flows.
+         - `model`: A probabilistic `Model` object.
+         """
+         if model is not None:
+             # Partition model
+             params, self._static = eqx.partition(model, model.filter_spec)
+
+             # Flatten params component
+             _, self._unflatten = jfu.ravel_pytree(params)
+         elif _static is not None and _unflatten is not None:
+             self._static = _static
+             self._unflatten = _unflatten
+         else:
+             raise ValueError("Either 'model' or '_static' and '_unflatten' must be specified.")
+
+         self.dim = base.dim
+         self.base = base
+         self.flows = flows
+
+     @property
+     def filter_spec(self) -> Self:
+         # Generate empty specification
+         filter_spec: Self = jtu.tree_map(lambda _: False, self)
+
+         # Specify variational parameters based on each flow's filter spec.
+         filter_spec: Self = eqx.tree_at(
+             lambda vari: vari.flows,
+             filter_spec,
+             replace=[flow.filter_spec for flow in self.flows],
+         )
+
+         return filter_spec
+
+     @eqx.filter_jit
+     def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
+         # Sample from the base distribution
+         draws: Array = self.base.sample(n, key)
+
+         assert len(draws.shape) == 2
+
+         # Apply forward transformations
+         for map in self.flows:
+             draws = map.forward(draws)
+
+         assert len(draws.shape) == 2
+
+         return draws
+
+     @eqx.filter_jit
+     def eval(self, draws: Array) -> Array:
+         raise RuntimeError("Evaluating the variational density for a normalizing flow requires an analytic inverse to exist, which many useful flows do not have. Therefore, do not use this method.")
+         return jnp.full(draws.shape[0], jnp.nan)
+
+     @eqx.filter_jit
+     def __eval(self, draws: Array) -> Tuple[Array, Array]:
+         """
+         Evaluate the posterior and variational densities together at the
+         transformed `draws` to avoid extra compute.
+
+         # Parameters
+         - `draws`: Draws from the base variational distribution.
+
+         # Returns
+         The posterior and variational densities as JAX Arrays.
+         """
+         # Evaluate base density
+         variational_evals: Array = self.base.eval(draws)
+
+         # Shape checks
+         assert len(variational_evals.shape) == 1
+         assert len(draws.shape) == 2
+
+         for map in self.flows:
+             # Apply transformation
+             draws, ljas = map.forward_and_adjust(draws)
+             assert len(draws.shape) == 2
+             assert len(ljas.shape) == 1
+
+             # Adjust variational density
+             variational_evals = variational_evals - ljas
+
+         # Evaluate posterior at final variational draws
+         posterior_evals = self.eval_model(draws)
+
+         # Shape checks
+         assert len(posterior_evals.shape) == 1
+         assert len(variational_evals.shape) == 1
+         assert posterior_evals.shape == variational_evals.shape
+
+         return posterior_evals, variational_evals
+
+     @eqx.filter_jit
+     def elbo(self, n: int, batch_size: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Scalar:
+         dyn, static = eqx.partition(self, self.filter_spec)
+
+         def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
+             self = eqx.combine(dyn, static)
+
+             # Split keys
+             keys = jr.split(key, n // batch_size)
+
+             # Split ELBO calculation into batches
+             def batched_elbo(carry: None, key: PRNGKeyArray) -> Tuple[None, Array]:
+                 # Draw from variational distribution
+                 draws: Array = self.base.sample(batch_size, key)
+
+                 # Evaluate posterior and variational densities
+                 batched_post_evals, batched_vari_evals = self.__eval(draws)
+
+                 # Compute ELBO estimate
+                 batched_elbo_evals: Array = batched_post_evals - batched_vari_evals
+
+                 return None, batched_elbo_evals
+
+             elbo_evals = scan(
+                 batched_elbo,
+                 init=None,
+                 xs=keys,
+                 length=n // batch_size
+             )[1]
+
+             # Compute average of ELBO estimates
+             return jnp.mean(elbo_evals)
+
+         return elbo(dyn, n, key)
+
+     @eqx.filter_jit
+     def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
+         dyn, static = eqx.partition(self, self.filter_spec)
+
+         # Define ELBO function
+         def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
+             self = eqx.combine(dyn, static)
+
+             # Split keys
+             keys = jr.split(key, n // batch_size)
+
+             # Split ELBO calculation into batches
+             def batched_elbo(carry: None, key: PRNGKeyArray) -> Tuple[None, Array]:
+                 # Draw from variational distribution
+                 draws: Array = self.base.sample(batch_size, key)
+
+                 # Evaluate posterior and variational densities
+                 batched_post_evals, batched_vari_evals = self.__eval(draws)
+
+                 # Compute ELBO estimate
+                 batched_elbo_evals: Array = batched_post_evals - batched_vari_evals
+
+                 return None, batched_elbo_evals
+
+             elbo_evals = scan(
+                 batched_elbo,
+                 init=None,
+                 xs=keys,
+                 length=n // batch_size
+             )[1]
+
+             # Compute average of ELBO estimates
+             return jnp.mean(elbo_evals)
+
+         # Map to its gradient
+         elbo_grad: Callable[
+             [Self, int, PRNGKeyArray], Self
+         ] = eqx.filter_grad(elbo)
+
+         return elbo_grad(dyn, n, key)
+
+     @eqx.filter_jit
+     def elbo_and_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Tuple[Scalar, Self]:
+         dyn, static = eqx.partition(self, self.filter_spec)
+
+         def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
+             self = eqx.combine(dyn, static)
+
+             # Split keys
+             keys = jr.split(key, n // batch_size)
+
+             # Split ELBO calculation into batches
+             def batched_elbo(carry: None, key: PRNGKeyArray) -> Tuple[None, Array]:
+                 # Draw from variational distribution
+                 draws: Array = self.base.sample(batch_size, key)
+
+                 # Evaluate posterior and variational densities
+                 batched_post_evals, batched_vari_evals = self.__eval(draws)
+
+                 # Compute ELBO estimate
+                 batched_elbo_evals: Array = batched_post_evals - batched_vari_evals
+
+                 return None, batched_elbo_evals
+
+             elbo_evals = scan(
+                 batched_elbo,
+                 init=None,
+                 xs=keys,
+                 length=n // batch_size
+             )[1]
+
+             # Compute average of ELBO estimates
+             return jnp.mean(elbo_evals)
+
+         # Map to its value & gradient
+         elbo_and_grad: Callable[
+             [Self, int, PRNGKeyArray], Tuple[Scalar, Self]
+         ] = eqx.filter_value_and_grad(elbo)
+
+         return elbo_and_grad(dyn, n, key)
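All three `elbo*` methods above share one batching pattern: split the key into `n // batch_size` subkeys, compute a batch of per-draw ELBO terms inside `jax.lax.scan`, and average. Below is a minimal, self-contained sketch of that pattern in plain JAX; the standard normal target and base here are stand-ins for illustration, not bayinx code:

```py
import jax.numpy as jnp
import jax.random as jr
from jax.lax import scan


def log_target(z):
    # Stand-in posterior log-density: standard normal.
    return -0.5 * jnp.sum(z**2, axis=-1) - 0.5 * z.shape[-1] * jnp.log(2 * jnp.pi)


def log_q(z):
    # Stand-in variational log-density (identical to the target here).
    return log_target(z)


def elbo(n: int, batch_size: int, key, dim: int = 2):
    keys = jr.split(key, n // batch_size)

    def batched_elbo(carry, key):
        # Draw a batch from q and return the per-draw ELBO terms.
        draws = jr.normal(key, (batch_size, dim))
        return None, log_target(draws) - log_q(draws)

    # scan returns (final_carry, stacked_outputs); keep the outputs.
    elbo_evals = scan(batched_elbo, None, keys)[1]
    return jnp.mean(elbo_evals)


print(elbo(1024, 128, jr.PRNGKey(0)))  # exactly 0.0 here, since q == target
```

Because `q` equals the target, every per-draw term is zero, which makes the batching machinery itself easy to verify in isolation.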
bayinx/vi/standard.py ADDED
@@ -0,0 +1,95 @@
+ from typing import Self, Tuple
+
+ import equinox as eqx
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree_util as jtu
+ from jax.flatten_util import ravel_pytree
+ from jaxtyping import Array, PRNGKeyArray, PyTree, Scalar
+
+ from bayinx.core.variational import M, Variational
+ from bayinx.dists import normal
+
+
+ class Standard(Variational[M]):
+     """
+     A standard normal approximation to a posterior distribution.
+
+     # Attributes
+     - `dim`: The dimension of the support.
+     """
+
+     def __init__(self, model: M):
+         """
+         Constructs a standard normal approximation to a posterior distribution.
+
+         # Parameters
+         - `model`: A probabilistic `Model` object.
+         """
+         # Partition model
+         params, self._static = eqx.partition(model, model.filter_spec)
+
+         # Flatten params component
+         params, self._unflatten = ravel_pytree(params)
+
+         # Store dimension of parameter space
+         self.dim = jnp.size(params)
+
+     @eqx.filter_jit
+     def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
+         # Sample variational draws
+         draws: Array = jr.normal(key=key, shape=(n, self.dim))
+
+         # Shape checks
+         assert len(draws.shape) == 2
+
+         return draws
+
+     @eqx.filter_jit
+     def eval(self, draws: Array) -> Array:
+         return normal.logprob(
+             x=draws,
+             mu=jnp.array(0.0),
+             sigma=jnp.array(1.0),
+         ).sum(axis=1)
+
+     @property
+     def filter_spec(self):
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         return filter_spec
+
+     @eqx.filter_jit
+     def elbo(self, n: int, batch_size: int, key: PRNGKeyArray) -> Scalar:
+         dyn, static = eqx.partition(self, self.filter_spec)
+
+         @eqx.filter_jit
+         def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
+             vari = eqx.combine(dyn, static)
+
+             # Sample draws from variational distribution
+             draws: Array = vari.sample(n, key)
+
+             # Evaluate posterior density for each draw
+             posterior_evals: Array = vari.eval_model(draws)
+
+             # Evaluate variational density for each draw
+             variational_evals: Array = vari.eval(draws)
+
+             # Evaluate ELBO
+             return jnp.mean(posterior_evals - variational_evals)
+
+         return elbo(dyn, n, key)
+
+     @eqx.filter_jit
+     def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
+         raise RuntimeError("Do not use the 'elbo_grad' method for a Standard variational approximation. It has no variational parameters.")
+         return self
+
+     def elbo_and_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Tuple[Scalar, PyTree]:
+         """
+         Evaluate the ELBO and its gradient.
+         """
+         raise RuntimeError("Do not use the 'elbo_and_grad' method for a Standard variational approximation. It has no variational parameters.")
+         return self.elbo(n, key), self
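`Standard` has no variational parameters: its `filter_spec` marks every leaf as static, so `eqx.partition` leaves nothing to differentiate, which is why `elbo_grad` and `elbo_and_grad` raise. A toy illustration of that Equinox pattern (the `Const` module below is hypothetical, not bayinx code):

```py
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu


class Const(eqx.Module):
    # Hypothetical module with a single array leaf.
    value: jnp.ndarray


m = Const(jnp.ones(3))

# An all-False filter spec marks every leaf as static.
spec = jtu.tree_map(lambda _: False, m)
dyn, static = eqx.partition(m, spec)

print(jtu.tree_leaves(dyn))  # [] -- no trainable leaves to differentiate
```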
bayinx-0.5.3.dist-info/METADATA ADDED
@@ -0,0 +1,93 @@
+ Metadata-Version: 2.4
+ Name: bayinx
+ Version: 0.5.3
+ Summary: Bayesian Inference with JAX
+ Author: Todd Pocuca
+ Maintainer: Todd Pocuca
+ License-File: LICENSE
+ Requires-Python: >=3.13
+ Requires-Dist: diffrax>=0.7.0
+ Requires-Dist: equinox>=0.13.2
+ Requires-Dist: jax>=0.8.0
+ Requires-Dist: jaxtyping>=0.2.36
+ Requires-Dist: optax>=0.2.4
+ Description-Content-Type: text/markdown
+
+ # Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+
+ Bayinx is a probabilistic programming language embedded in Python, powered by
+ [JAX](https://github.com/jax-ml/jax). It is heavily inspired by, and aims for
+ feature parity with, [Stan](https://mc-stan.org/), but it extends the types of
+ objects you can work with and focuses on normalizing-flow variational
+ inference for sampling.
+
+
+ ## Coming From Stan
+
+ There are a few differences between the syntax of Bayinx and Stan.
+ First, because Bayinx is embedded in Python, model definitions are Pythonic
+ and rely on defining a class that inherits from the `Model` base class:
+
+ ```py
+ class MyModel(Model, init=False):
+     # ...
+ ```
+
+ > Note: Specify `init=False` to keep static type checkers from raising
+ irrelevant errors; more importantly, it is a reminder that you should
+ **NOT** implement your own `__init__` method!
+
+ Stan's `data` and `parameters` blocks are combined into the attribute
+ definitions in Bayinx. For example, to model a normal distribution with an
+ unknown mean and a variance of 1, we might write:
+
+ ```py
+ class MyModel(Model, init=False):
+     mean: Continuous[Array] = define(shape=())    # a scalar mean parameter
+     x: Observed[Array] = define(shape='n_obs')    # a vector of observed values
+
+     # ...
+ ```
+
+ Stan's `model` block corresponds to implementing the `model` method in Bayinx:
+
+ ```py
+ class MyModel(Model, init=False):
+     mean: Continuous[Array] = define(shape=())
+     x: Observed[Array] = define(shape='n_obs')
+
+     def model(self, target):
+         # Equivalent to 'x ~ normal(mean, 1.0)' in Stan
+         self.x << Normal(self.mean, 1.0)
+
+         return target
+ ```
+
+ Notice that Stan's `~` operator has been replaced with `<<`, and that nodes of a model are referenced through `self`.
+
+ > Note: Bayinx does not currently have an analogue of `transformed data` or `transformed parameters`; however, one is likely to be included in a future release.
+
+ You can then construct the variational approximation to the posterior:
+
+ ```py
+ import bayinx as byx
+ from bayinx.flows import DiagAffine
+ import jax.numpy as jnp
+
+ # Fit variational approximation
+ posterior = byx.Posterior(MyModel, n_obs=3, x=jnp.array([-1.0, 0.0, 1.0]))
+ posterior.configure(flowspecs=[DiagAffine()])
+ posterior.fit()
+ ```
+
+ The approximation can then be queried by sampling nodes:
+
+ ```py
+ mean_draws = posterior.sample('mean', 10000)
+ print(mean_draws.mean())
+ ```
+
+
+ ## Roadmap
+ - [ ] Implement OT-Flow: https://arxiv.org/abs/2006.00104
+ - [ ]
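Stitching the README snippets above into one script gives the workflow below. Only `Posterior` and `DiagAffine` have import paths shown in the README; the paths for `define`, `Continuous`, `Observed`, and `Normal` are guesses based on the file list in this diff (`bayinx/nodes/`, `bayinx/dists/`), so treat them as assumptions:

```py
import bayinx as byx
import jax.numpy as jnp
from jaxtyping import Array

from bayinx import define                      # import path assumed
from bayinx.nodes import Continuous, Observed  # path assumed from bayinx/nodes/
from bayinx.dists import Normal                # path assumed from bayinx/dists/
from bayinx.flows import DiagAffine


class MyModel(byx.Model, init=False):
    mean: Continuous[Array] = define(shape=())    # scalar parameter
    x: Observed[Array] = define(shape='n_obs')    # observed vector

    def model(self, target):
        # 'x ~ normal(mean, 1.0)' in Stan
        self.x << Normal(self.mean, 1.0)
        return target


posterior = byx.Posterior(MyModel, n_obs=3, x=jnp.array([-1.0, 0.0, 1.0]))
posterior.configure(flowspecs=[DiagAffine()])
posterior.fit()

mean_draws = posterior.sample('mean', 10000)
print(mean_draws.mean())  # should land near 0.0, the sample mean of x
```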
bayinx-0.5.3.dist-info/RECORD ADDED
@@ -0,0 +1,44 @@
+ bayinx/__init__.py,sha256=hhpQ8JM9kzJLUqeV_72ZbUf9gxTFH-SFkq-GyaYscvI,126
+ bayinx/ops.py,sha256=QsKrk2tMOxcYXuUidQyiA1e5KbV5WH80XUvM-PV8wZc,2666
+ bayinx/posterior.py,sha256=ab7Ubx3BDTkejaNciWDR-J9GEgbygRZPUgc23rz7YOg,6760
+ bayinx/constraints/__init__.py,sha256=E9WFI5xPAuVOFTzaLKgG2uV8k5Pho9w0mlmsMYrkSSI,154
+ bayinx/constraints/identity.py,sha256=IMR2WHB_GL89IOgkY7sOOScMyMIJkkGpMWpbUVkfOUY,629
+ bayinx/constraints/interval.py,sha256=OeM_aED8pZPdhpyrxOUQjg7IXWjXX2GiJGOoocaU7WI,1811
+ bayinx/constraints/lower.py,sha256=xXP-vrpQwnBSUN_1f1qYSSKApodfHfveI0f86h6go_k,1517
+ bayinx/constraints/upper.py,sha256=RWppD7SKP6KciQt_Wrd_w59vIykm_vaUtE82E9UEcBs,1529
+ bayinx/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/core/constraint.py,sha256=w5Roomp-YITFyIcfXpd_P18JYEdYo6yVcNFaYk0MiEE,932
+ bayinx/core/context.py,sha256=kSvZFHyveUAkEr-gokZjsG6zp7QqrzcXePFcaTtBFqM,1074
+ bayinx/core/distribution.py,sha256=d0dMxzh5NxOv9-IYqV9D-amBcGFLsHsxn8D598IRKEI,990
+ bayinx/core/flow.py,sha256=SSGxdLAZLuBUQ8g1D6-QwZPjWuVGTK0fzH6X8CAgFxw,2736
+ bayinx/core/model.py,sha256=3FiqNPcLI7QTPhwvH-orYgcKoitozqT0kGxMnoa0dV4,7808
+ bayinx/core/node.py,sha256=5nRi3YzKGTUhl6-AOPNkYmYX-xb8DjuOt_XLMVX9k8Q,6134
+ bayinx/core/types.py,sha256=of52_tUQurdyfbSzdHjQ0EJUWl8DFEW2Pia9Lr_n3Lk,378
+ bayinx/core/utils.py,sha256=_2CxYev5Gu85wMqqqydENmnygvYJh2zB76lBD2-s3y0,3519
+ bayinx/core/variational.py,sha256=J7vwGKIulVdaZCoqf7XaRt0Ku7LPzNaqa4xOH_fQ-Nw,4800
+ bayinx/dists/__init__.py,sha256=7-nWGyK5obf3lxSlAv5JbzEuQVrRxIC1tU_HseMUtXU,218
+ bayinx/dists/bernoulli.py,sha256=vHatAx6FCWCjm-epKu-QqlHjrUN_RXkidGWmzTlurOk,5178
+ bayinx/dists/binomial.py,sha256=ap6KN4qW_AbGMZNdz5_8VHtlQYRmIZGlhy9-Z6vulsM,6220
+ bayinx/dists/exponential.py,sha256=229Ank-WcJsWhzEY93PvclFWz11VDQLaM1uLEAstf7Y,5705
+ bayinx/dists/normal.py,sha256=mnAG2GMjyOdHQ6VAdkDrVJ-d3Rzg2MYqLgRq6-LWwBI,5948
+ bayinx/dists/poisson.py,sha256=iBcREtJg0xjXhxKSh0lMTmr7tWAfsO0O3FO-Ft8a6vw,5496
+ bayinx/flows/__init__.py,sha256=SozrytzAbeTckrcH_zpiS2bjnhkPsrwZc3nYsUhP4YE,299
+ bayinx/flows/diagaffine.py,sha256=cMc2QyC9xipy29FQc7Rzwjo0ga8ajHhxA4hwAIhthSg,3402
+ bayinx/flows/fullaffine.py,sha256=b6eXUu72ft_Jf2Kx1wUmdMFGZzn-Bz4J76Hpz7ZKoIw,3594
+ bayinx/flows/lowrankaffine.py,sha256=heCWs2iZTTf1BXL6M16yMW8p896Bf15TkMhHv-RaObQ,4769
+ bayinx/flows/otflow.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/flows/planar.py,sha256=NLtDcXrT6sVJP4YW0dtgdkins_iEOmvLg0sYsLKUndU,4280
+ bayinx/flows/radial.py,sha256=lkIotasV1VteGgHOQxkN2UXRH-ja4oUkU4sXYB52JdQ,6
+ bayinx/flows/sylvester.py,sha256=K4Dc8QaKE0WQDRNCMhNdyuBjPnfhljQYTIFz2KBm-h8,6523
+ bayinx/nodes/__init__.py,sha256=HWb4Yi-wSd_Fr7AywsAed8EezJXemKagZFRlqypu_qY,141
+ bayinx/nodes/continuous.py,sha256=k5GkM6MO6ag4r9wXY48c1j3Hjc0KSalq4xUY9nIz07M,1817
+ bayinx/nodes/observed.py,sha256=9SiOHAeLVYSAaFJFuLgF6yiwk8cQgSreSXs5ZvLHaRE,943
+ bayinx/nodes/stochastic.py,sha256=2G89o4vC0I0Pw_2zLStaEHJ5OXUSAbVA0hyenyo0aIs,614
+ bayinx/vi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/vi/meanfield.py,sha256=N9YeItj9Dk6OpdtBQ4mm-OE4nQXJ-xtLDlbK1t9EP5M,3805
+ bayinx/vi/normalizing_flow.py,sha256=LsTZgi6ehDZ4LLuHfnLeaGEo2Z0_CZhRePcoNCXJgyo,8226
+ bayinx/vi/standard.py,sha256=r4dydWZMQv5QyCPjBLG3eyU3wczxFUu3pqdA2569eoI,2941
+ bayinx-0.5.3.dist-info/METADATA,sha256=ptDkp2X80xOz0pkKajne-pJSllGH2fNHLZzwYYwJt38,2937
+ bayinx-0.5.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ bayinx-0.5.3.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
+ bayinx-0.5.3.dist-info/RECORD,,
{bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL RENAMED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: hatchling 1.27.0
+ Generator: hatchling 1.28.0
  Root-Is-Purelib: true
  Tag: py3-none-any
bayinx/core/_constraint.py DELETED
@@ -1,28 +0,0 @@
- from abc import abstractmethod
- from typing import Tuple
-
- import equinox as eqx
- from jaxtyping import Scalar
-
- from bayinx.core._parameter import Parameter
-
-
- class Constraint(eqx.Module):
-     """
-     Abstract base class for defining parameter constraints.
-     """
-
-     @abstractmethod
-     def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
-         """
-         Applies the constraining transformation to a parameter and computes the log-absolute-Jacobian of the transformation.
-
-         # Parameters
-         - `x`: The unconstrained `Parameter`.
-
-         # Returns
-         A tuple containing:
-         - The constrained `Parameter`.
-         - A scalar Array representing the log-absolute-Jacobian of the transformation.
-         """
-         pass
bayinx/core/_flow.py DELETED
@@ -1,80 +0,0 @@
- from abc import abstractmethod
- from typing import Callable, Dict, Self, Tuple
-
- import equinox as eqx
- import jax.tree_util as jtu
- from jaxtyping import Array, Float
-
-
- class Flow(eqx.Module):
-     """
-     An abstract base class for a flow (of a normalizing flow).
-
-     # Attributes
-     - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
-     - `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
-     """
-
-     params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
-
-     @abstractmethod
-     def forward(self, draws: Array) -> Array:
-         """
-         Computes the forward transformation of `draws`.
-         """
-         pass
-
-     @abstractmethod
-     def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
-         """
-         Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.
-
-         # Returns
-         A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
-         """
-         pass
-
-     # Default filter specification
-     @property
-     @eqx.filter_jit
-     def filter_spec(self):
-         """
-         Generates a filter specification to subset relevant parameters for the flow.
-         """
-         # Generate empty specification
-         filter_spec = jtu.tree_map(lambda _: False, self)
-
-         # Specify JAX Array parameters
-         filter_spec = eqx.tree_at(
-             lambda flow: flow.params,
-             filter_spec,
-             replace=jtu.tree_map(eqx.is_array, self.params),
-         )
-
-         return filter_spec
-
-     @eqx.filter_jit
-     def constrain_params(self: Self):
-         """
-         Constrain `params` to the appropriate domain.
-
-         # Returns
-         A dictionary of transformed JAX Arrays representing the constrained parameters.
-         """
-         t_params = self.params
-
-         for par, map in self.constraints.items():
-             t_params[par] = map(t_params[par])
-
-         return t_params
-
-     @eqx.filter_jit
-     def transform_params(self: Self) -> Dict[str, Array]:
-         """
-         Apply a custom transformation to `params` if needed.
-
-         # Returns
-         A dictionary of transformed JAX Arrays representing the transformed parameters.
-         """
-         return self.constrain_params()
bayinx/core/_model.py DELETED
@@ -1,98 +0,0 @@
- from abc import abstractmethod
- from dataclasses import field, fields
- from typing import Any, Self, Tuple
-
- import equinox as eqx
- import jax.numpy as jnp
- import jax.tree as jt
- from jaxtyping import Scalar
-
- from ._constraint import Constraint
- from ._parameter import Parameter
-
-
- def constrain(constraint: Constraint):
-     """Defines constraint metadata."""
-     return field(metadata={"constraint": constraint})
-
-
- class Model(eqx.Module):
-     """
-     An abstract base class used to define probabilistic models.
-
-     Annotate parameter attributes with `Parameter`.
-
-     Include constraints by setting them equal to `constrain(Constraint)`.
-     """
-
-     @abstractmethod
-     def eval(self, data: Any) -> Scalar:
-         pass
-
-     # Default filter specification
-     @property
-     @eqx.filter_jit
-     def filter_spec(self) -> Self:
-         """
-         Generates a filter specification to subset relevant parameters for the model.
-         """
-         # Generate empty specification
-         filter_spec: Self = jt.map(lambda _: False, self)
-
-         for f in fields(self):
-             # Extract attribute from field
-             attr = getattr(self, f.name)
-
-             # Check if attribute is a parameter
-             if isinstance(attr, Parameter):
-                 # Update filter specification for parameter
-                 filter_spec = eqx.tree_at(
-                     lambda model: getattr(model, f.name),
-                     filter_spec,
-                     replace=attr.filter_spec,
-                 )
-
-         return filter_spec
-
-     @eqx.filter_jit
-     def constrain_params(self) -> Tuple[Self, Scalar]:
-         """
-         Constrain parameters to the appropriate domain.
-
-         # Returns
-         A constrained `Model` object and the adjustment to the posterior.
-         """
-         constrained: Self = self
-         target: Scalar = jnp.array(0.0)
-
-         for f in fields(self):
-             # Extract attribute
-             attr = getattr(self, f.name)
-
-             # Check if constrained parameter
-             if isinstance(attr, Parameter) and "constraint" in f.metadata:
-                 param = attr
-                 constraint = f.metadata["constraint"]
-
-                 # Apply constraint
-                 param, laj = constraint.constrain(param)
-
-                 # Update parameters for constrained model
-                 constrained = eqx.tree_at(
-                     lambda model: getattr(model, f.name), constrained, replace=param
-                 )
-
-                 # Adjust posterior density
-                 target += laj
-
-         return constrained, target
-
-     @eqx.filter_jit
-     def transform_params(self) -> Tuple[Self, Scalar]:
-         """
-         Apply a custom transformation to parameters if needed (defaults to constrained parameters).
-
-         # Returns
-         A transformed `Model` object and the adjustment to the posterior.
-         """
-         return self.constrain_params()
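The deleted `constrain()` helper above rode on the stdlib dataclass metadata channel: it stashed the `Constraint` in `field(metadata=...)`, and `constrain_params` later recovered it through `dataclasses.fields`. A self-contained sketch of just that mechanism, with a stand-in `Positive` constraint in place of bayinx's `Constraint` classes:

```py
import math
from dataclasses import dataclass, field, fields


class Positive:
    # Stand-in constraint: x -> exp(x), with log|Jacobian| = x.
    def constrain(self, x: float) -> tuple[float, float]:
        return math.exp(x), x


def constrain(constraint):
    """Attach a constraint to a dataclass field via metadata."""
    return field(default=0.0, metadata={"constraint": constraint})


@dataclass
class Params:
    mu: float = 0.0
    sigma: float = constrain(Positive())


p = Params(mu=1.0, sigma=-0.5)
target = 0.0

# Mirror of Model.constrain_params: walk the fields, apply any attached
# constraint, and accumulate the log-Jacobian adjustment.
for f in fields(p):
    c = f.metadata.get("constraint")
    if c is not None:
        value, laj = c.constrain(getattr(p, f.name))
        setattr(p, f.name, value)
        target += laj

print(p.sigma, target)  # exp(-0.5) ~= 0.607, adjustment -0.5
```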