bayinx 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/constraints/__init__.py +1 -1
- bayinx/core/__init__.py +10 -1
- bayinx/core/_optimization.py +94 -0
- bayinx/core/_parameter.py +2 -0
- bayinx/core/_variational.py +13 -3
- bayinx/dists/__init__.py +1 -1
- bayinx/dists/censored/posnormal/r.py +7 -7
- bayinx/dists/uniform.py +39 -5
- bayinx/mhx/opt/__init__.py +0 -0
- bayinx/mhx/vi/__init__.py +1 -1
- bayinx/mhx/vi/meanfield.py +4 -2
- bayinx/mhx/vi/normalizing_flow.py +3 -1
- bayinx/mhx/vi/standard.py +1 -1
- {bayinx-0.3.11.dist-info → bayinx-0.3.13.dist-info}/METADATA +1 -1
- {bayinx-0.3.11.dist-info → bayinx-0.3.13.dist-info}/RECORD +17 -15
- {bayinx-0.3.11.dist-info → bayinx-0.3.13.dist-info}/WHEEL +0 -0
- {bayinx-0.3.11.dist-info → bayinx-0.3.13.dist-info}/licenses/LICENSE +0 -0
bayinx/constraints/__init__.py
CHANGED
bayinx/core/__init__.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1
1
|
from ._constraint import Constraint
|
2
2
|
from ._flow import Flow
|
3
3
|
from ._model import Model, constrain
|
4
|
+
from ._optimization import optimize_model
|
4
5
|
from ._parameter import Parameter
|
5
6
|
from ._variational import Variational
|
6
7
|
|
7
|
-
__all__ = [
|
8
|
+
# Public API of `bayinx.core`: names re-exported from the private submodules.
__all__ = [
    "Constraint", "Flow", "Model", "constrain",
    "optimize_model", "Parameter", "Variational",
]
|
@@ -0,0 +1,94 @@
|
|
1
|
+
from typing import Any, Tuple
|
2
|
+
|
3
|
+
import equinox as eqx
|
4
|
+
import jax.lax as lax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import optax as opx
|
7
|
+
from jaxtyping import PyTree, Scalar
|
8
|
+
from optax import GradientTransformation, OptState, Schedule
|
9
|
+
|
10
|
+
from ._model import Model
|
11
|
+
|
12
|
+
|
13
|
+
@eqx.filter_jit
def optimize_model(
    model: Model,
    max_iters: int,
    data: Any = None,
    learning_rate: float = 1,
    weight_decay: float = 0.0,
    tolerance: float = 1e-4,
) -> Model:
    """
    Optimize the dynamic parameters of the model.

    The value returned by `model.eval(data)` (the posterior) is maximized with
    NAdamW under a warmup + cosine-decay learning-rate schedule. The loop runs
    exactly `max_iters` update steps.

    # Parameters
    - `model`: The model whose dynamic parameters (selected by `model.filter_spec`) are optimized.
    - `max_iters`: Maximum number of iterations for the optimization loop.
    - `data`: Data to evaluate the model with.
    - `learning_rate`: Peak learning rate, reached at the end of the warmup phase
      (the schedule starts from 1e-16).
    - `weight_decay`: Weight decay for the optimizer.
    - `tolerance`: Relative tolerance of loss decrease for stopping early. Currently
      unused: the loop always runs `max_iters` iterations.

    # Returns
    A model with the same static components as `model` and optimized dynamic parameters.
    """
    # Get dynamic and static components of model
    dyn, static = eqx.partition(model, model.filter_spec)

    # Derive gradient for posterior
    @eqx.filter_jit
    @eqx.filter_grad
    def eval_grad(dyn: Model):
        # Reconstruct model
        model: Model = eqx.combine(dyn, static)

        # Evaluate posterior
        return model.eval(data)

    # Construct scheduler. optax's `decay_steps` is the *total* schedule length
    # (warmup included), so it must be the full iteration count; otherwise the
    # cosine decay finishes early and the final iterations run at a learning
    # rate of 0. The length is clamped to at least 1 because optax requires a
    # positive decay length (the loop itself does not run when `max_iters` is 0).
    total_steps: int = max(max_iters, 1)
    warmup_steps: int = int(total_steps / 10)
    schedule: Schedule = opx.warmup_cosine_decay_schedule(
        init_value=1e-16,
        peak_value=learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
    )

    # The posterior is maximized, so gradients are negated before being fed to
    # NAdamW (which produces descent updates).
    optim: GradientTransformation = opx.chain(
        opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
    )
    opt_state: OptState = optim.init(dyn)

    @eqx.filter_jit
    def condition(state: Tuple[PyTree, OptState, Scalar]):
        # Continue while the iteration counter is below `max_iters`
        _, _, i = state

        return i < max_iters

    @eqx.filter_jit
    def body(state: Tuple[PyTree, OptState, Scalar]):
        # Unpack iteration state
        dyn, opt_state, i = state

        # Update iteration
        i = i + 1

        # Evaluate gradient of posterior
        updates = eval_grad(dyn)

        # Compute updates
        updates, opt_state = optim.update(
            updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
        )

        # Update model
        dyn = eqx.apply_updates(dyn, updates)

        return dyn, opt_state, i

    # Run optimization loop
    dyn = lax.while_loop(
        cond_fun=condition,
        body_fun=body,
        init_val=(dyn, opt_state, jnp.array(0, jnp.uint32)),
    )[0]

    # Return optimized model
    return eqx.combine(dyn, static)
|
bayinx/core/_parameter.py
CHANGED
bayinx/core/_variational.py
CHANGED
@@ -13,7 +13,9 @@ from optax import GradientTransformation, OptState, Schedule
|
|
13
13
|
|
14
14
|
from ._model import Model
|
15
15
|
|
16
|
-
M = TypeVar(
|
16
|
+
# Type variable for the `Model` subclass that `Variational` is generic over.
M = TypeVar("M", bound=Model)
|
17
|
+
|
18
|
+
|
17
19
|
class Variational(eqx.Module, Generic[M]):
|
18
20
|
"""
|
19
21
|
An abstract base class used to define variational methods.
|
@@ -80,6 +82,7 @@ class Variational(eqx.Module, Generic[M]):
|
|
80
82
|
# Evaluate posterior density
|
81
83
|
return model.eval(data)
|
82
84
|
|
85
|
+
# TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
|
83
86
|
@eqx.filter_jit
|
84
87
|
def fit(
|
85
88
|
self,
|
@@ -107,7 +110,10 @@ class Variational(eqx.Module, Generic[M]):
|
|
107
110
|
|
108
111
|
# Construct scheduler
|
109
112
|
schedule: Schedule = opx.warmup_cosine_decay_schedule(
|
110
|
-
init_value=1e-16,
|
113
|
+
init_value=1e-16,
|
114
|
+
peak_value=learning_rate,
|
115
|
+
warmup_steps=int(max_iters / 10),
|
116
|
+
decay_steps=max_iters - int(max_iters / 10),
|
111
117
|
)
|
112
118
|
|
113
119
|
# Initialize optimizer
|
@@ -163,7 +169,11 @@ class Variational(eqx.Module, Generic[M]):
|
|
163
169
|
|
164
170
|
@eqx.filter_jit
|
165
171
|
def posterior_predictive(
|
166
|
-
self,
|
172
|
+
self,
|
173
|
+
func: Callable[[M, Any], Array],
|
174
|
+
n: int,
|
175
|
+
data: Any = None,
|
176
|
+
key: Key = jr.PRNGKey(0),
|
167
177
|
) -> Array:
|
168
178
|
# Sample draws from the variational approximation
|
169
179
|
draws: Array = self.sample(n, key)
|
bayinx/dists/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
2
|
import jax.random as jr
|
3
|
+
from jax.scipy.special import ndtri
|
3
4
|
from jaxtyping import Array, ArrayLike, Float, Key
|
4
5
|
|
5
|
-
from bayinx.dists import posnormal
|
6
|
+
from bayinx.dists import normal, posnormal
|
6
7
|
|
7
8
|
|
8
9
|
def prob(
|
@@ -78,12 +79,13 @@ def logprob(
|
|
78
79
|
|
79
80
|
return evals
|
80
81
|
|
82
|
+
|
81
83
|
def sample(
|
82
84
|
n: int,
|
83
85
|
mu: Float[ArrayLike, "..."],
|
84
86
|
sigma: Float[ArrayLike, "..."],
|
85
87
|
censor: Float[ArrayLike, "..."] = jnp.inf,
|
86
|
-
key: Key = jr.PRNGKey(0)
|
88
|
+
key: Key = jr.PRNGKey(0),
|
87
89
|
) -> Float[Array, "..."]:
|
88
90
|
"""
|
89
91
|
Sample from a right-censored positive Normal distribution.
|
@@ -107,10 +109,8 @@ def sample(
|
|
107
109
|
# Derive shape
|
108
110
|
shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape, censor.shape)
|
109
111
|
|
110
|
-
#
|
111
|
-
draws = jr.
|
112
|
-
|
113
|
-
# Censor values
|
114
|
-
draws = jnp.where(censor <= draws, censor, draws)
|
112
|
+
# Construct draws
|
113
|
+
draws = jr.uniform(key, shape)
|
114
|
+
draws = mu + sigma * ndtri(normal.cdf(-mu/sigma, 0.0, 1.0) + draws * normal.cdf(mu/sigma, 0.0, 1.0))
|
115
115
|
|
116
116
|
return draws
|
bayinx/dists/uniform.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import jax.lax as _lax
|
2
2
|
import jax.numpy as jnp
|
3
|
-
|
3
|
+
import jax.random as jr
|
4
|
+
from jaxtyping import Array, ArrayLike, Float, Key
|
4
5
|
|
5
6
|
|
6
7
|
def prob(
|
@@ -17,8 +18,10 @@ def prob(
|
|
17
18
|
# Returns
|
18
19
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
19
20
|
"""
|
21
|
+
# Cast to Array
|
22
|
+
x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
|
20
23
|
|
21
|
-
return 1.0 / (ub - lb)
|
24
|
+
return 1.0 / (ub - lb)
|
22
25
|
|
23
26
|
|
24
27
|
def logprob(
|
@@ -35,8 +38,10 @@ def logprob(
|
|
35
38
|
# Returns
|
36
39
|
The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
37
40
|
"""
|
41
|
+
# Cast to Array
|
42
|
+
x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
|
38
43
|
|
39
|
-
return _lax.log(1.0) - _lax.log(ub - lb)
|
44
|
+
return _lax.log(1.0) - _lax.log(ub - lb)
|
40
45
|
|
41
46
|
|
42
47
|
def uprob(
|
@@ -53,8 +58,10 @@ def uprob(
|
|
53
58
|
# Returns
|
54
59
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
55
60
|
"""
|
61
|
+
# Cast to Array
|
62
|
+
x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
|
56
63
|
|
57
|
-
return jnp.ones(jnp.
|
64
|
+
return jnp.ones(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
|
58
65
|
|
59
66
|
|
60
67
|
def ulogprob(
|
@@ -71,5 +78,32 @@ def ulogprob(
|
|
71
78
|
# Returns
|
72
79
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
73
80
|
"""
|
81
|
+
# Cast to Array
|
82
|
+
x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
|
74
83
|
|
75
|
-
return jnp.zeros(jnp.
|
84
|
+
return jnp.zeros(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
|
85
|
+
|
86
|
+
def sample(
    n: int,
    lb: Float[ArrayLike, "..."],
    ub: Float[ArrayLike, "..."],
    key: Key = jr.PRNGKey(0),
) -> Float[Array, "..."]:
    """
    Sample from a Uniform distribution.

    # Parameters
    - `n`: Number of draws to sample per-parameter.
    - `lb`: The lower bound parameter(s).
    - `ub`: The upper bound parameter(s).
    - `key`: PRNG key used to generate the draws. The default key is created once
      when the module is imported, so calls that omit `key` return identical draws.

    # Returns
    Draws from a Uniform distribution. The output will have the shape of (n,) + the broadcasted shapes of `lb` and `ub`.
    """
    # Cast to Array
    lb, ub = jnp.asarray(lb), jnp.asarray(ub)

    # Derive shape: a leading draw axis followed by the broadcast parameter shape
    shape = (n,) + jnp.broadcast_shapes(lb.shape, ub.shape)

    # Construct draws; `lb` and `ub` broadcast against `shape`
    draws = jr.uniform(key, shape, minval=lb, maxval=ub)

    return draws
|
File without changes
|
bayinx/mhx/vi/__init__.py
CHANGED
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -10,7 +10,9 @@ from jaxtyping import Array, Float, Key, Scalar
|
|
10
10
|
from bayinx.core import Model, Variational
|
11
11
|
from bayinx.dists import normal
|
12
12
|
|
13
|
-
M = TypeVar(
|
13
|
+
# Type variable for the `Model` subclass that `MeanField` is generic over.
M = TypeVar("M", bound=Model)
|
14
|
+
|
15
|
+
|
14
16
|
class MeanField(Variational, Generic[M]):
|
15
17
|
"""
|
16
18
|
A fully factorized Gaussian approximation to a posterior distribution.
|
@@ -19,7 +21,7 @@ class MeanField(Variational, Generic[M]):
|
|
19
21
|
- `var_params`: The variational parameters for the approximation.
|
20
22
|
"""
|
21
23
|
|
22
|
-
var_params: Dict[str, Float[Array, "..."]]
|
24
|
+
var_params: Dict[str, Float[Array, "..."]] # todo: just expand to attributes
|
23
25
|
|
24
26
|
def __init__(self, model: M):
|
25
27
|
"""
|
@@ -9,7 +9,9 @@ from jaxtyping import Array, Key, Scalar
|
|
9
9
|
|
10
10
|
from bayinx.core import Flow, Model, Variational
|
11
11
|
|
12
|
-
M = TypeVar(
|
12
|
+
# Type variable for the `Model` subclass that `NormalizingFlow` is generic over.
M = TypeVar("M", bound=Model)
|
13
|
+
|
14
|
+
|
13
15
|
class NormalizingFlow(Variational, Generic[M]):
|
14
16
|
"""
|
15
17
|
An ordered collection of diffeomorphisms that map a base distribution to a
|
bayinx/mhx/vi/standard.py
CHANGED
@@ -1,35 +1,37 @@
|
|
1
1
|
bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
|
2
2
|
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/constraints/__init__.py,sha256=
|
3
|
+
bayinx/constraints/__init__.py,sha256=027WJxRLkybXZkmusfvR6iZayY2pDid7Tw6TTTeA6ko,64
|
4
4
|
bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
|
5
|
-
bayinx/core/__init__.py,sha256=
|
5
|
+
bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
|
6
6
|
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
7
7
|
bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
8
|
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
9
|
-
bayinx/core/
|
10
|
-
bayinx/core/
|
11
|
-
bayinx/
|
9
|
+
bayinx/core/_optimization.py,sha256=cXO07guDG5kd64kWq4_de-gLgxTT6vIOU3IOL3TMl6U,2583
|
10
|
+
bayinx/core/_parameter.py,sha256=UfLonbTwxzr1g76Cf3HzAh9u4UBtlcaLByYLXgq-aCQ,1082
|
11
|
+
bayinx/core/_variational.py,sha256=bQYiN8c4AGPt4hsNT68zN7J6o0fdsLGwVjbDyl62LnI,5639
|
12
|
+
bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
|
12
13
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
14
|
bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
|
14
15
|
bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
|
15
16
|
bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
|
16
|
-
bayinx/dists/uniform.py,sha256=
|
17
|
+
bayinx/dists/uniform.py,sha256=2ZQxEfAX5TFgSPuQ8joFDuFbd_NfmQ1GvmGGjusqvNQ,3461
|
17
18
|
bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
|
18
19
|
bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
19
20
|
bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
|
20
21
|
bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
21
|
-
bayinx/dists/censored/posnormal/r.py,sha256=
|
22
|
+
bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-t9kidH4,3456
|
22
23
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
23
|
-
bayinx/mhx/
|
24
|
-
bayinx/mhx/vi/
|
25
|
-
bayinx/mhx/vi/
|
26
|
-
bayinx/mhx/vi/
|
24
|
+
bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
|
+
bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
|
26
|
+
bayinx/mhx/vi/meanfield.py,sha256=kx5WeD93-XO8NHxd4L4pZ8V19Y9B6j-yE3Y5OBXMcTk,3743
|
27
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
|
28
|
+
bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
|
27
29
|
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
28
30
|
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
29
31
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
30
32
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
31
33
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
32
|
-
bayinx-0.3.
|
33
|
-
bayinx-0.3.
|
34
|
-
bayinx-0.3.
|
35
|
-
bayinx-0.3.
|
34
|
+
bayinx-0.3.13.dist-info/METADATA,sha256=1mcHMTXrzMGPcibMy_vHdBJMNPtT0VUF83eVBS_MJlg,3080
|
35
|
+
bayinx-0.3.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
36
|
+
bayinx-0.3.13.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
37
|
+
bayinx-0.3.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|