bayinx 0.3.14__py3-none-any.whl → 0.3.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/_optimization.py +5 -6
- bayinx/core/_variational.py +25 -9
- bayinx/dists/gamma2.py +1 -1
- bayinx/dists/posnormal.py +0 -7
- bayinx/mhx/vi/meanfield.py +3 -2
- {bayinx-0.3.14.dist-info → bayinx-0.3.16.dist-info}/METADATA +3 -3
- {bayinx-0.3.14.dist-info → bayinx-0.3.16.dist-info}/RECORD +9 -9
- {bayinx-0.3.14.dist-info → bayinx-0.3.16.dist-info}/WHEEL +0 -0
- {bayinx-0.3.14.dist-info → bayinx-0.3.16.dist-info}/licenses/LICENSE +0 -0
bayinx/core/_optimization.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Tuple, TypeVar
|
1
|
+
from typing import Any, Callable, Tuple, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.lax as lax
|
@@ -20,7 +20,7 @@ def optimize_model(
|
|
20
20
|
tolerance: float = 1e-4,
|
21
21
|
) -> M:
|
22
22
|
"""
|
23
|
-
Optimize the
|
23
|
+
Optimize the parameters of the model.
|
24
24
|
|
25
25
|
# Parameters
|
26
26
|
- `max_iters`: Maximum number of iterations for the optimization loop.
|
@@ -33,14 +33,13 @@ def optimize_model(
|
|
33
33
|
dyn, static = eqx.partition(model, model.filter_spec)
|
34
34
|
|
35
35
|
# Derive gradient for posterior
|
36
|
-
|
37
|
-
@eqx.filter_grad
|
38
|
-
def eval_grad(dyn: M):
|
36
|
+
def eval(dyn: M) -> Scalar:
|
39
37
|
# Reconstruct model
|
40
38
|
model: M = eqx.combine(dyn, static)
|
41
39
|
|
42
40
|
# Evaluate posterior
|
43
41
|
return model.eval(data)
|
42
|
+
eval_grad: Callable[[M], M] = eqx.filter_jit(eqx.filter_grad(eval))
|
44
43
|
|
45
44
|
# Construct scheduler
|
46
45
|
schedule: Schedule = opx.warmup_cosine_decay_schedule(
|
@@ -71,7 +70,7 @@ def optimize_model(
|
|
71
70
|
i = i + 1
|
72
71
|
|
73
72
|
# Evaluate gradient of posterior
|
74
|
-
updates = eval_grad(dyn)
|
73
|
+
updates: PyTree = eval_grad(dyn)
|
75
74
|
|
76
75
|
# Compute updates
|
77
76
|
updates, opt_state = optim.update(
|
bayinx/core/_variational.py
CHANGED
@@ -175,17 +175,33 @@ class Variational(eqx.Module, Generic[M]):
|
|
175
175
|
data: Any = None,
|
176
176
|
key: Key = jr.PRNGKey(0),
|
177
177
|
) -> Array:
|
178
|
-
# Sample
|
179
|
-
|
178
|
+
# Sample a single draw to evaluate shape of output
|
179
|
+
draw: Array = self.sample(1, key)
|
180
|
+
output: Array = func(self._unflatten(draw), data)
|
181
|
+
|
182
|
+
# Allocate space for results
|
183
|
+
results: Array = jnp.zeros((n,) + output.shape, dtype=output.dtype)
|
184
|
+
|
185
|
+
@eqx.filter_jit
|
186
|
+
def body_fun(i: int, state: Tuple[Key, Array]) -> Tuple[Key, Array]:
|
187
|
+
# Unpack state
|
188
|
+
key, results = state
|
189
|
+
|
190
|
+
# Update PRNG key
|
191
|
+
next, key = jr.split(key)
|
192
|
+
|
193
|
+
# Draw from variational
|
194
|
+
draw: Array = self.sample(1, key)
|
180
195
|
|
181
|
-
# Evaluate posterior predictive
|
182
|
-
@jax.jit
|
183
|
-
@jax.vmap
|
184
|
-
def evaluate(draw: Array):
|
185
196
|
# Reconstruct model
|
186
197
|
model: M = self._unflatten(draw)
|
187
198
|
|
188
|
-
#
|
189
|
-
|
199
|
+
# Update results with output
|
200
|
+
results = results.at[i].set(func(model, data))
|
201
|
+
|
202
|
+
return next, results
|
203
|
+
|
204
|
+
# Evaluate draws
|
205
|
+
results: Array = jax.lax.fori_loop(0, n, body_fun, (key, results))[1]
|
190
206
|
|
191
|
-
return
|
207
|
+
return results
|
bayinx/dists/gamma2.py
CHANGED
bayinx/dists/posnormal.py
CHANGED
@@ -139,9 +139,6 @@ def cdf(
|
|
139
139
|
|
140
140
|
# Returns
|
141
141
|
The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
142
|
-
|
143
|
-
# Notes
|
144
|
-
Not numerically stable for small `x`.
|
145
142
|
"""
|
146
143
|
# Cast to Array
|
147
144
|
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
@@ -160,7 +157,6 @@ def cdf(
|
|
160
157
|
return evals
|
161
158
|
|
162
159
|
|
163
|
-
# TODO: make numerically stable
|
164
160
|
def logcdf(
|
165
161
|
x: Float[ArrayLike, "..."],
|
166
162
|
mu: Float[ArrayLike, "..."],
|
@@ -176,9 +172,6 @@ def logcdf(
|
|
176
172
|
|
177
173
|
# Returns
|
178
174
|
The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
179
|
-
|
180
|
-
# Notes
|
181
|
-
Not numerically stable for small `x`.
|
182
175
|
"""
|
183
176
|
# Cast to Array
|
184
177
|
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -23,12 +23,13 @@ class MeanField(Variational, Generic[M]):
|
|
23
23
|
|
24
24
|
var_params: Dict[str, Float[Array, "..."]] # todo: just expand to attributes
|
25
25
|
|
26
|
-
def __init__(self, model: M):
|
26
|
+
def __init__(self, model: M, init_log_std: float = 0.0):
|
27
27
|
"""
|
28
28
|
Constructs an unoptimized meanfield posterior approximation.
|
29
29
|
|
30
30
|
# Parameters
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
|
+
- `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
|
32
33
|
"""
|
33
34
|
# Partition model
|
34
35
|
params, self._constraints = eqx.partition(model, model.filter_spec)
|
@@ -39,7 +40,7 @@ class MeanField(Variational, Generic[M]):
|
|
39
40
|
# Initialize variational parameters
|
40
41
|
self.var_params = {
|
41
42
|
"mean": params,
|
42
|
-
"log_std": jnp.
|
43
|
+
"log_std": jnp.full(params.size, init_log_std, params.dtype),
|
43
44
|
}
|
44
45
|
|
45
46
|
@property
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: bayinx
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.16
|
4
4
|
Summary: Bayesian Inference with JAX
|
5
5
|
License-File: LICENSE
|
6
6
|
Requires-Python: >=3.12
|
@@ -10,9 +10,9 @@ Requires-Dist: jaxtyping>=0.2.36
|
|
10
10
|
Requires-Dist: optax>=0.2.4
|
11
11
|
Description-Content-Type: text/markdown
|
12
12
|
|
13
|
-
# <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
13
|
+
# `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
14
14
|
|
15
|
-
The endgoal of this project is to build a Bayesian inference library that is similar in feel to `Stan`(where you can define a probabilistic model with syntax that is
|
15
|
+
The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax that is similar to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of methods for estimation (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
|
16
16
|
|
17
17
|
In the short-term, I'm going to focus on:
|
18
18
|
1) Implementing as much machinery as I feel is enough.
|
@@ -6,14 +6,14 @@ bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
|
|
6
6
|
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
7
7
|
bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
8
|
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
9
|
-
bayinx/core/_optimization.py,sha256=
|
9
|
+
bayinx/core/_optimization.py,sha256=mmeVUqfFARz8F7q4LRl-uEwVWzekmzh-9o7PnuvsHZk,2651
|
10
10
|
bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
|
11
|
-
bayinx/core/_variational.py,sha256=
|
11
|
+
bayinx/core/_variational.py,sha256=MbOgWs0-mkEwIYsh1Hhn7anWkQ7fAGjkr6VJybYAuR0,6185
|
12
12
|
bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
|
13
13
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
14
|
-
bayinx/dists/gamma2.py,sha256=
|
14
|
+
bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
|
15
15
|
bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
|
16
|
-
bayinx/dists/posnormal.py,sha256=
|
16
|
+
bayinx/dists/posnormal.py,sha256=cOLCdd39DX3v8DD-seSIKNk4OfdNfaYaLzpCh_xBGyw,7150
|
17
17
|
bayinx/dists/uniform.py,sha256=2ZQxEfAX5TFgSPuQ8joFDuFbd_NfmQ1GvmGGjusqvNQ,3461
|
18
18
|
bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
|
19
19
|
bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
@@ -23,7 +23,7 @@ bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-
|
|
23
23
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
24
24
|
bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
|
26
|
-
bayinx/mhx/vi/meanfield.py,sha256=
|
26
|
+
bayinx/mhx/vi/meanfield.py,sha256=wM0v6Q2m0pPyEdOaT8DvFqdRLsHijr7AYNQKAcsZtRQ,3881
|
27
27
|
bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
|
28
28
|
bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
|
29
29
|
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
@@ -31,7 +31,7 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
|
|
31
31
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
32
32
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
33
33
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
34
|
-
bayinx-0.3.
|
35
|
-
bayinx-0.3.
|
36
|
-
bayinx-0.3.
|
37
|
-
bayinx-0.3.
|
34
|
+
bayinx-0.3.16.dist-info/METADATA,sha256=9oLNZAc1YvN9FpWiuFFlztL-OkhclOYPBQFkuiHmUT8,3087
|
35
|
+
bayinx-0.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
36
|
+
bayinx-0.3.16.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
37
|
+
bayinx-0.3.16.dist-info/RECORD,,
|
File without changes
|
File without changes
|