bayinx 0.3.14__py3-none-any.whl → 0.3.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- from typing import Any, Tuple, TypeVar
1
+ from typing import Any, Callable, Tuple, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.lax as lax
@@ -20,7 +20,7 @@ def optimize_model(
20
20
  tolerance: float = 1e-4,
21
21
  ) -> M:
22
22
  """
23
- Optimize the dynamic parameters of the model.
23
+ Optimize the parameters of the model.
24
24
 
25
25
  # Parameters
26
26
  - `max_iters`: Maximum number of iterations for the optimization loop.
@@ -33,14 +33,13 @@ def optimize_model(
33
33
  dyn, static = eqx.partition(model, model.filter_spec)
34
34
 
35
35
  # Derive gradient for posterior
36
- @eqx.filter_jit
37
- @eqx.filter_grad
38
- def eval_grad(dyn: M):
36
+ def eval(dyn: M) -> Scalar:
39
37
  # Reconstruct model
40
38
  model: M = eqx.combine(dyn, static)
41
39
 
42
40
  # Evaluate posterior
43
41
  return model.eval(data)
42
+ eval_grad: Callable[[M], M] = eqx.filter_jit(eqx.filter_grad(eval))
44
43
 
45
44
  # Construct scheduler
46
45
  schedule: Schedule = opx.warmup_cosine_decay_schedule(
@@ -71,7 +70,7 @@ def optimize_model(
71
70
  i = i + 1
72
71
 
73
72
  # Evaluate gradient of posterior
74
- updates = eval_grad(dyn)
73
+ updates: PyTree = eval_grad(dyn)
75
74
 
76
75
  # Compute updates
77
76
  updates, opt_state = optim.update(
@@ -175,17 +175,33 @@ class Variational(eqx.Module, Generic[M]):
175
175
  data: Any = None,
176
176
  key: Key = jr.PRNGKey(0),
177
177
  ) -> Array:
178
- # Sample draws from the variational approximation
179
- draws: Array = self.sample(n, key)
178
+ # Sample a single draw to evaluate shape of output
179
+ draw: Array = self.sample(1, key)
180
+ output: Array = func(self._unflatten(draw), data)
181
+
182
+ # Allocate space for results
183
+ results: Array = jnp.zeros((n,) + output.shape, dtype=output.dtype)
184
+
185
+ @eqx.filter_jit
186
+ def body_fun(i: int, state: Tuple[Key, Array]) -> Tuple[Key, Array]:
187
+ # Unpack state
188
+ key, results = state
189
+
190
+ # Update PRNG key
191
+ next, key = jr.split(key)
192
+
193
+ # Draw from variational
194
+ draw: Array = self.sample(1, key)
180
195
 
181
- # Evaluate posterior predictive
182
- @jax.jit
183
- @jax.vmap
184
- def evaluate(draw: Array):
185
196
  # Reconstruct model
186
197
  model: M = self._unflatten(draw)
187
198
 
188
- # Evaluate
189
- return func(model, data)
199
+ # Update results with output
200
+ results = results.at[i].set(func(model, data))
201
+
202
+ return next, results
203
+
204
+ # Evaluate draws
205
+ results: Array = jax.lax.fori_loop(0, n, body_fun, (key, results))[1]
190
206
 
191
- return evaluate(draws)
207
+ return results
bayinx/dists/gamma2.py CHANGED
@@ -46,4 +46,4 @@ def logprob(
46
46
  + nu * (lax.log(nu) - lax.log(mu))
47
47
  + (nu - 1.0) * lax.log(x)
48
48
  - (x * nu / mu)
49
- ) # pyright: ignore
49
+ )
bayinx/dists/posnormal.py CHANGED
@@ -139,9 +139,6 @@ def cdf(
139
139
 
140
140
  # Returns
141
141
  The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
142
-
143
- # Notes
144
- Not numerically stable for small `x`.
145
142
  """
146
143
  # Cast to Array
147
144
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
@@ -160,7 +157,6 @@ def cdf(
160
157
  return evals
161
158
 
162
159
 
163
- # TODO: make numerically stable
164
160
  def logcdf(
165
161
  x: Float[ArrayLike, "..."],
166
162
  mu: Float[ArrayLike, "..."],
@@ -176,9 +172,6 @@ def logcdf(
176
172
 
177
173
  # Returns
178
174
  The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
179
-
180
- # Notes
181
- Not numerically stable for small `x`.
182
175
  """
183
176
  # Cast to Array
184
177
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
@@ -23,12 +23,13 @@ class MeanField(Variational, Generic[M]):
23
23
 
24
24
  var_params: Dict[str, Float[Array, "..."]] # todo: just expand to attributes
25
25
 
26
- def __init__(self, model: M):
26
+ def __init__(self, model: M, init_log_std: float = 0.0):
27
27
  """
28
28
  Constructs an unoptimized meanfield posterior approximation.
29
29
 
30
30
  # Parameters
31
31
  - `model`: A probabilistic `Model` object.
32
+ - `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
32
33
  """
33
34
  # Partition model
34
35
  params, self._constraints = eqx.partition(model, model.filter_spec)
@@ -39,7 +40,7 @@ class MeanField(Variational, Generic[M]):
39
40
  # Initialize variational parameters
40
41
  self.var_params = {
41
42
  "mean": params,
42
- "log_std": jnp.zeros(params.size, dtype=params.dtype),
43
+ "log_std": jnp.full(params.size, init_log_std, params.dtype),
43
44
  }
44
45
 
45
46
  @property
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.14
3
+ Version: 0.3.16
4
4
  Summary: Bayesian Inference with JAX
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -10,9 +10,9 @@ Requires-Dist: jaxtyping>=0.2.36
10
10
  Requires-Dist: optax>=0.2.4
11
11
  Description-Content-Type: text/markdown
12
12
 
13
- # <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
13
+ # `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
14
14
 
15
- The endgoal of this project is to build a Bayesian inference library that is similar in feel to `Stan`(where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models(e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of methods for estimation(point optimization, variational methods, MCMC) while keeping everything performant(hence using `JAX`).
15
+ The endgoal of this project is to build a Bayesian inference library that is similar in feel to `Stan`(where you can define a probabilistic model with syntax that is similar to how you would write it out on a chalkboard) but allows for arbitrary models(e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of methods for estimation(point optimization, variational methods, MCMC) while keeping everything performant(hence using `JAX`).
16
16
 
17
17
  In the short-term, I'm going to focus on:
18
18
  1) Implementing as much machinery as I feel is enough.
@@ -6,14 +6,14 @@ bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
6
6
  bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
7
7
  bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
8
  bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
9
- bayinx/core/_optimization.py,sha256=dL3COMFTP0_FeD44hreNybh_UD6zqJy6sSP54ITJsBc,2605
9
+ bayinx/core/_optimization.py,sha256=mmeVUqfFARz8F7q4LRl-uEwVWzekmzh-9o7PnuvsHZk,2651
10
10
  bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
11
- bayinx/core/_variational.py,sha256=bQYiN8c4AGPt4hsNT68zN7J6o0fdsLGwVjbDyl62LnI,5639
11
+ bayinx/core/_variational.py,sha256=MbOgWs0-mkEwIYsh1Hhn7anWkQ7fAGjkr6VJybYAuR0,6185
12
12
  bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
13
13
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
14
- bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
14
+ bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
15
15
  bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
16
- bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
16
+ bayinx/dists/posnormal.py,sha256=cOLCdd39DX3v8DD-seSIKNk4OfdNfaYaLzpCh_xBGyw,7150
17
17
  bayinx/dists/uniform.py,sha256=2ZQxEfAX5TFgSPuQ8joFDuFbd_NfmQ1GvmGGjusqvNQ,3461
18
18
  bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
19
19
  bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
@@ -23,7 +23,7 @@ bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-
23
23
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
24
24
  bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
26
- bayinx/mhx/vi/meanfield.py,sha256=kx5WeD93-XO8NHxd4L4pZ8V19Y9B6j-yE3Y5OBXMcTk,3743
26
+ bayinx/mhx/vi/meanfield.py,sha256=wM0v6Q2m0pPyEdOaT8DvFqdRLsHijr7AYNQKAcsZtRQ,3881
27
27
  bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
28
28
  bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
29
29
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
@@ -31,7 +31,7 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
31
31
  bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
32
32
  bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
33
33
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
34
- bayinx-0.3.14.dist-info/METADATA,sha256=ba_0a1aYSrKIGhGFG_IWA7TltEmGpK78IqCFeLV7XVI,3080
35
- bayinx-0.3.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
- bayinx-0.3.14.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
37
- bayinx-0.3.14.dist-info/RECORD,,
34
+ bayinx-0.3.16.dist-info/METADATA,sha256=9oLNZAc1YvN9FpWiuFFlztL-OkhclOYPBQFkuiHmUT8,3087
35
+ bayinx-0.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
+ bayinx-0.3.16.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
37
+ bayinx-0.3.16.dist-info/RECORD,,