bayinx 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/constraints/__init__.py +1 -1
- bayinx/core/__init__.py +10 -1
- bayinx/core/_optimization.py +94 -0
- bayinx/core/_parameter.py +2 -0
- bayinx/core/_variational.py +15 -5
- bayinx/dists/__init__.py +1 -1
- bayinx/dists/censored/posnormal/r.py +2 -1
- bayinx/mhx/opt/__init__.py +0 -0
- bayinx/mhx/vi/__init__.py +1 -1
- bayinx/mhx/vi/meanfield.py +4 -2
- bayinx/mhx/vi/normalizing_flow.py +3 -1
- bayinx/mhx/vi/standard.py +1 -1
- {bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/METADATA +1 -1
- {bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/RECORD +16 -14
- {bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/WHEEL +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/licenses/LICENSE +0 -0
bayinx/constraints/__init__.py
CHANGED
bayinx/core/__init__.py
CHANGED
@@ -1,7 +1,16 @@
 from ._constraint import Constraint
 from ._flow import Flow
 from ._model import Model, constrain
+from ._optimization import optimize_model
 from ._parameter import Parameter
 from ._variational import Variational
 
-__all__ = [
+__all__ = [
+    "Constraint",
+    "Flow",
+    "Model",
+    "constrain",
+    "optimize_model",
+    "Parameter",
+    "Variational",
+]
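In practice, this change means the new optimizer entry point is importable directly from the core subpackage. A one-line usage sketch (assuming the 0.3.12 wheel is installed):

```python
# optimize_model is newly re-exported from bayinx.core in 0.3.12
from bayinx.core import optimize_model
```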
bayinx/core/_optimization.py
ADDED
@@ -0,0 +1,94 @@
+from typing import Any, Tuple
+
+import equinox as eqx
+import jax.lax as lax
+import jax.numpy as jnp
+import optax as opx
+from jaxtyping import PyTree, Scalar
+from optax import GradientTransformation, OptState, Schedule
+
+from ._model import Model
+
+
+@eqx.filter_jit
+def optimize_model(
+    model: Model,
+    max_iters: int,
+    data: Any = None,
+    learning_rate: float = 1,
+    weight_decay: float = 0.0,
+    tolerance: float = 1e-4,
+) -> Model:
+    """
+    Optimize the dynamic parameters of the model.
+
+    # Parameters
+    - `max_iters`: Maximum number of iterations for the optimization loop.
+    - `data`: Data to evaluate the model with.
+    - `learning_rate`: Initial learning rate for optimization.
+    - `weight_decay`: Weight decay for the optimizer.
+    - `tolerance`: Relative tolerance of loss decrease for stopping early (not implemented in the loop).
+    """
+    # Get dynamic and static components of model
+    dyn, static = eqx.partition(model, model.filter_spec)
+
+    # Derive gradient for posterior
+    @eqx.filter_jit
+    @eqx.filter_grad
+    def eval_grad(dyn: Model):
+        # Reconstruct model
+        model: Model = eqx.combine(dyn, static)
+
+        # Evaluate posterior
+        return model.eval(data)
+
+    # Construct scheduler
+    schedule: Schedule = opx.warmup_cosine_decay_schedule(
+        init_value=1e-16,
+        peak_value=learning_rate,
+        warmup_steps=int(max_iters / 10),
+        decay_steps=max_iters - int(max_iters / 10),
+    )
+
+    optim: GradientTransformation = opx.chain(
+        opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+    )
+    opt_state: OptState = optim.init(dyn)
+
+    @eqx.filter_jit
+    def condition(state: Tuple[PyTree, OptState, Scalar]):
+        # Unpack iteration state
+        current_opt_dyn, opt_state, i = state
+
+        return i < max_iters
+
+    @eqx.filter_jit
+    def body(state: Tuple[PyTree, OptState, Scalar]):
+        # Unpack iteration state
+        dyn, opt_state, i = state
+
+        # Update iteration
+        i = i + 1
+
+        # Evaluate gradient of posterior
+        updates = eval_grad(dyn)
+
+        # Compute updates
+        updates, opt_state = optim.update(
+            updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
+        )
+
+        # Update model
+        dyn = eqx.apply_updates(dyn, updates)
+
+        return dyn, opt_state, i
+
+    # Run optimization loop
+    dyn = lax.while_loop(
+        cond_fun=condition,
+        body_fun=body,
+        init_val=(dyn, opt_state, jnp.array(0, jnp.uint32)),
+    )[0]
+
+    # Return optimized model
+    return eqx.combine(dyn, static)
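The new `optimize_model` builds an optax warmup-cosine schedule, flips the gradient sign with `opx.scale(-1.0)` so that the `nadamw` minimizer ascends the log posterior returned by `model.eval(data)`, and iterates with `lax.while_loop`. The sketch below reproduces that optimizer pattern on a toy quadratic log-density; the target, the parameter vector, and the hyperparameter values are illustrative assumptions, not part of bayinx:

```python
# Minimal sketch of the schedule + scale(-1.0) + nadamw + while_loop pattern
# used above, applied to a toy log-density (everything here is illustrative).
import jax
import jax.lax as lax
import jax.numpy as jnp
import optax as opx

max_iters = 1_000
learning_rate = 1e-1

def log_density(theta):
    # Toy "log posterior": peaks at theta = [3, 3].
    return -0.5 * jnp.sum((theta - 3.0) ** 2)

# Warmup-cosine schedule shaped like the one constructed in optimize_model.
schedule = opx.warmup_cosine_decay_schedule(
    init_value=1e-16,
    peak_value=learning_rate,
    warmup_steps=max_iters // 10,
    decay_steps=max_iters - max_iters // 10,
)

# scale(-1.0) negates the gradient, so the nadamw minimizer performs
# gradient ascent on log_density, mirroring the chain used above.
optim = opx.chain(opx.scale(-1.0), opx.nadamw(schedule, weight_decay=0.0))

theta = jnp.zeros(2)
opt_state = optim.init(theta)

def condition(state):
    # Run for a fixed number of iterations (no early stopping, as above).
    return state[2] < max_iters

def body(state):
    theta, opt_state, i = state
    grads = jax.grad(log_density)(theta)
    updates, opt_state = optim.update(grads, opt_state, theta)
    theta = opx.apply_updates(theta, updates)
    return theta, opt_state, i + 1

theta, _, _ = lax.while_loop(
    condition, body, (theta, opt_state, jnp.array(0, jnp.uint32))
)
# theta is now close to jnp.array([3.0, 3.0])
```

As the docstring notes, `tolerance` is accepted but early stopping is not wired into the loop, so both the new module and this sketch run for the full `max_iters`.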
bayinx/core/_parameter.py
CHANGED
bayinx/core/_variational.py
CHANGED
@@ -13,7 +13,9 @@ from optax import GradientTransformation, OptState, Schedule
 
 from ._model import Model
 
-M = TypeVar(
+M = TypeVar("M", bound=Model)
+
+
 class Variational(eqx.Module, Generic[M]):
     """
     An abstract base class used to define variational methods.
@@ -80,6 +82,7 @@ class Variational(eqx.Module, Generic[M]):
         # Evaluate posterior density
         return model.eval(data)
 
+    # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
     @eqx.filter_jit
     def fit(
         self,
@@ -107,7 +110,10 @@ class Variational(eqx.Module, Generic[M]):
 
         # Construct scheduler
        schedule: Schedule = opx.warmup_cosine_decay_schedule(
-            init_value=1e-16,
+            init_value=1e-16,
+            peak_value=learning_rate,
+            warmup_steps=int(max_iters / 10),
+            decay_steps=max_iters - int(max_iters / 10),
         )
 
         # Initialize optimizer
@@ -163,7 +169,11 @@ class Variational(eqx.Module, Generic[M]):
 
     @eqx.filter_jit
     def posterior_predictive(
-        self,
+        self,
+        func: Callable[[M, Any], Array],
+        n: int,
+        data: Any = None,
+        key: Key = jr.PRNGKey(0),
     ) -> Array:
         # Sample draws from the variational approximation
         draws: Array = self.sample(n, key)
@@ -171,11 +181,11 @@ class Variational(eqx.Module, Generic[M]):
         # Evaluate posterior predictive
         @jax.jit
         @jax.vmap
-        def evaluate(draw: Array
+        def evaluate(draw: Array):
             # Reconstruct model
             model: M = self._unflatten(draw)
 
             # Evaluate
             return func(model, data)
 
-        return evaluate(draws
+        return evaluate(draws)
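The reworked `posterior_predictive` takes `func`, `n`, `data`, and `key` explicitly, samples `n` draws from the variational approximation, and evaluates `func` once per draw through `jax.vmap`. A standalone sketch of that per-draw vmap pattern, with illustrative stand-ins for the draws, the unflattening step, and the predictive function:

```python
# Sketch of the per-draw vmap pattern used by posterior_predictive.
# The draws, data, and predictive function below are illustrative stand-ins.
import jax
import jax.numpy as jnp
import jax.random as jr

draws = jr.normal(jr.PRNGKey(0), (100, 2))  # stand-in for variational draws
data = jnp.linspace(0.0, 1.0, 5)            # stand-in for observed inputs

def func(params, data):
    # Stand-in predictive function: a linear predictor per draw.
    return params[0] + params[1] * data

@jax.jit
@jax.vmap
def evaluate(draw):
    # In the package this is where the model is rebuilt from the flat draw
    # (self._unflatten(draw)) before func(model, data) is evaluated.
    return func(draw, data)

preds = evaluate(draws)  # shape (100, 5): one row per posterior draw
```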
bayinx/dists/__init__.py
CHANGED
bayinx/dists/censored/posnormal/r.py
CHANGED
@@ -78,12 +78,13 @@ def logprob(
 
     return evals
 
+
 def sample(
     n: int,
     mu: Float[ArrayLike, "..."],
     sigma: Float[ArrayLike, "..."],
     censor: Float[ArrayLike, "..."] = jnp.inf,
-    key: Key = jr.PRNGKey(0)
+    key: Key = jr.PRNGKey(0),
 ) -> Float[Array, "..."]:
     """
     Sample from a right-censored positive Normal distribution.
bayinx/mhx/opt/__init__.py
File without changes
bayinx/mhx/vi/__init__.py
CHANGED
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -10,7 +10,9 @@ from jaxtyping import Array, Float, Key, Scalar
 from bayinx.core import Model, Variational
 from bayinx.dists import normal
 
-M = TypeVar(
+M = TypeVar("M", bound=Model)
+
+
 class MeanField(Variational, Generic[M]):
     """
     A fully factorized Gaussian approximation to a posterior distribution.
@@ -19,7 +21,7 @@ class MeanField(Variational, Generic[M]):
     - `var_params`: The variational parameters for the approximation.
     """
 
-    var_params: Dict[str, Float[Array, "..."]]
+    var_params: Dict[str, Float[Array, "..."]]  # todo: just expand to attributes
 
     def __init__(self, model: M):
         """
bayinx/mhx/vi/normalizing_flow.py
CHANGED
@@ -9,7 +9,9 @@ from jaxtyping import Array, Key, Scalar
 
 from bayinx.core import Flow, Model, Variational
 
-M = TypeVar(
+M = TypeVar("M", bound=Model)
+
+
 class NormalizingFlow(Variational, Generic[M]):
     """
     An ordered collection of diffeomorphisms that map a base distribution to a
bayinx/mhx/vi/standard.py
CHANGED
{bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/RECORD
CHANGED
@@ -1,14 +1,15 @@
 bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
 bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bayinx/constraints/__init__.py,sha256=
+bayinx/constraints/__init__.py,sha256=027WJxRLkybXZkmusfvR6iZayY2pDid7Tw6TTTeA6ko,64
 bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
-bayinx/core/__init__.py,sha256=
+bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
 bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
 bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
 bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
-bayinx/core/
-bayinx/core/
-bayinx/
+bayinx/core/_optimization.py,sha256=cXO07guDG5kd64kWq4_de-gLgxTT6vIOU3IOL3TMl6U,2583
+bayinx/core/_parameter.py,sha256=UfLonbTwxzr1g76Cf3HzAh9u4UBtlcaLByYLXgq-aCQ,1082
+bayinx/core/_variational.py,sha256=bQYiN8c4AGPt4hsNT68zN7J6o0fdsLGwVjbDyl62LnI,5639
+bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
 bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
 bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
 bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
@@ -18,18 +19,19 @@ bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U
 bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
 bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
 bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
-bayinx/dists/censored/posnormal/r.py,sha256=
+bayinx/dists/censored/posnormal/r.py,sha256=r8QfEThaaqsqqrr6PHWiY0EdfnjQcbKdTAHN2GmDVzI,3428
 bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-bayinx/mhx/
-bayinx/mhx/vi/
-bayinx/mhx/vi/
-bayinx/mhx/vi/
+bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
+bayinx/mhx/vi/meanfield.py,sha256=kx5WeD93-XO8NHxd4L4pZ8V19Y9B6j-yE3Y5OBXMcTk,3743
+bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
+bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
 bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
 bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
 bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
 bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
 bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
-bayinx-0.3.
-bayinx-0.3.
-bayinx-0.3.
-bayinx-0.3.
+bayinx-0.3.12.dist-info/METADATA,sha256=RvwZfeFmarMW4SlQhurD5TLnAS0sge9lVNcHPUNjIIM,3080
+bayinx-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+bayinx-0.3.12.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
+bayinx-0.3.12.dist-info/RECORD,,
{bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/WHEEL
File without changes
{bayinx-0.3.10.dist-info → bayinx-0.3.12.dist-info}/licenses/LICENSE
File without changes