bayinx 0.3.11__py3-none-any.whl → 0.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,3 @@
1
1
  from bayinx.constraints.lower import Lower
2
2
 
3
- __all__ = ['Lower']
3
+ __all__ = ["Lower"]
bayinx/core/__init__.py CHANGED
@@ -1,7 +1,16 @@
1
1
  from ._constraint import Constraint
2
2
  from ._flow import Flow
3
3
  from ._model import Model, constrain
4
+ from ._optimization import optimize_model
4
5
  from ._parameter import Parameter
5
6
  from ._variational import Variational
6
7
 
7
- __all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
8
+ __all__ = [
9
+ "Constraint",
10
+ "Flow",
11
+ "Model",
12
+ "constrain",
13
+ "optimize_model",
14
+ "Parameter",
15
+ "Variational",
16
+ ]
@@ -0,0 +1,94 @@
1
+ from typing import Any, Tuple
2
+
3
+ import equinox as eqx
4
+ import jax.lax as lax
5
+ import jax.numpy as jnp
6
+ import optax as opx
7
+ from jaxtyping import PyTree, Scalar
8
+ from optax import GradientTransformation, OptState, Schedule
9
+
10
+ from ._model import Model
11
+
12
+
13
+ @eqx.filter_jit
14
+ def optimize_model(
15
+ model: Model,
16
+ max_iters: int,
17
+ data: Any = None,
18
+ learning_rate: float = 1,
19
+ weight_decay: float = 0.0,
20
+ tolerance: float = 1e-4,
21
+ ) -> Model:
22
+ """
23
+ Optimize the dynamic parameters of the model.
24
+
25
+ # Parameters
26
+ - `max_iters`: Maximum number of iterations for the optimization loop.
27
+ - `data`: Data to evaluate the model with.
28
+ - `learning_rate`: Initial learning rate for optimization.
29
+ - `weight_decay`: Weight decay for the optimizer.
30
+ - `tolerance`: Relative tolerance of loss decrease for stopping early (not implemented in the loop).
31
+ """
32
+ # Get dynamic and static componts of model
33
+ dyn, static = eqx.partition(model, model.filter_spec)
34
+
35
+ # Derive gradient for posterior
36
+ @eqx.filter_jit
37
+ @eqx.filter_grad
38
+ def eval_grad(dyn: Model):
39
+ # Reconstruct model
40
+ model: Model = eqx.combine(dyn, static)
41
+
42
+ # Evaluate posterior
43
+ return model.eval(data)
44
+
45
+ # Construct scheduler
46
+ schedule: Schedule = opx.warmup_cosine_decay_schedule(
47
+ init_value=1e-16,
48
+ peak_value=learning_rate,
49
+ warmup_steps=int(max_iters / 10),
50
+ decay_steps=max_iters - int(max_iters / 10),
51
+ )
52
+
53
+ optim: GradientTransformation = opx.chain(
54
+ opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
55
+ )
56
+ opt_state: OptState = optim.init(dyn)
57
+
58
+ @eqx.filter_jit
59
+ def condition(state: Tuple[PyTree, OptState, Scalar]):
60
+ # Unpack iteration state
61
+ current_opt_dyn, opt_state, i = state
62
+
63
+ return i < max_iters
64
+
65
+ @eqx.filter_jit
66
+ def body(state: Tuple[PyTree, OptState, Scalar]):
67
+ # Unpack iteration state
68
+ dyn, opt_state, i = state
69
+
70
+ # Update iteration
71
+ i = i + 1
72
+
73
+ # Evaluate gradient of posterior
74
+ updates = eval_grad(dyn)
75
+
76
+ # Compute updates
77
+ updates, opt_state = optim.update(
78
+ updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
79
+ )
80
+
81
+ # Update model
82
+ dyn = eqx.apply_updates(dyn, updates)
83
+
84
+ return dyn, opt_state, i
85
+
86
+ # Run optimization loop
87
+ dyn = lax.while_loop(
88
+ cond_fun=condition,
89
+ body_fun=body,
90
+ init_val=(dyn, opt_state, jnp.array(0, jnp.uint32)),
91
+ )[0]
92
+
93
+ # Return optimized model
94
+ return eqx.combine(dyn, static)
bayinx/core/_parameter.py CHANGED
@@ -5,6 +5,8 @@ import jax.tree as jt
5
5
  from jaxtyping import PyTree
6
6
 
7
7
  T = TypeVar("T", bound=PyTree)
8
+
9
+
8
10
  class Parameter(eqx.Module, Generic[T]):
9
11
  """
10
12
  A container for a parameter of a `Model`.
@@ -13,7 +13,9 @@ from optax import GradientTransformation, OptState, Schedule
13
13
 
14
14
  from ._model import Model
15
15
 
16
- M = TypeVar('M', bound=Model)
16
+ M = TypeVar("M", bound=Model)
17
+
18
+
17
19
  class Variational(eqx.Module, Generic[M]):
18
20
  """
19
21
  An abstract base class used to define variational methods.
@@ -80,6 +82,7 @@ class Variational(eqx.Module, Generic[M]):
80
82
  # Evaluate posterior density
81
83
  return model.eval(data)
82
84
 
85
+ # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
83
86
  @eqx.filter_jit
84
87
  def fit(
85
88
  self,
@@ -107,7 +110,10 @@ class Variational(eqx.Module, Generic[M]):
107
110
 
108
111
  # Construct scheduler
109
112
  schedule: Schedule = opx.warmup_cosine_decay_schedule(
110
- init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
113
+ init_value=1e-16,
114
+ peak_value=learning_rate,
115
+ warmup_steps=int(max_iters / 10),
116
+ decay_steps=max_iters - int(max_iters / 10),
111
117
  )
112
118
 
113
119
  # Initialize optimizer
@@ -163,7 +169,11 @@ class Variational(eqx.Module, Generic[M]):
163
169
 
164
170
  @eqx.filter_jit
165
171
  def posterior_predictive(
166
- self, func: Callable[[M, Any], Array], n: int, data: Any = None, key: Key = jr.PRNGKey(0)
172
+ self,
173
+ func: Callable[[M, Any], Array],
174
+ n: int,
175
+ data: Any = None,
176
+ key: Key = jr.PRNGKey(0),
167
177
  ) -> Array:
168
178
  # Sample draws from the variational approximation
169
179
  draws: Array = self.sample(n, key)
bayinx/dists/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from bayinx.dists import censored, gamma2, normal, posnormal
2
2
 
3
- __all__ = ['censored', "gamma2", "normal", "posnormal"]
3
+ __all__ = ["censored", "gamma2", "normal", "posnormal"]
@@ -78,12 +78,13 @@ def logprob(
78
78
 
79
79
  return evals
80
80
 
81
+
81
82
  def sample(
82
83
  n: int,
83
84
  mu: Float[ArrayLike, "..."],
84
85
  sigma: Float[ArrayLike, "..."],
85
86
  censor: Float[ArrayLike, "..."] = jnp.inf,
86
- key: Key = jr.PRNGKey(0)
87
+ key: Key = jr.PRNGKey(0),
87
88
  ) -> Float[Array, "..."]:
88
89
  """
89
90
  Sample from a right-censored positive Normal distribution.
File without changes
bayinx/mhx/vi/__init__.py CHANGED
@@ -2,4 +2,4 @@ from bayinx.mhx.vi.meanfield import MeanField
2
2
  from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
3
3
  from bayinx.mhx.vi.standard import Standard
4
4
 
5
- __all__ = ['MeanField', 'NormalizingFlow', 'Standard']
5
+ __all__ = ["MeanField", "NormalizingFlow", "Standard"]
@@ -10,7 +10,9 @@ from jaxtyping import Array, Float, Key, Scalar
10
10
  from bayinx.core import Model, Variational
11
11
  from bayinx.dists import normal
12
12
 
13
- M = TypeVar('M', bound=Model)
13
+ M = TypeVar("M", bound=Model)
14
+
15
+
14
16
  class MeanField(Variational, Generic[M]):
15
17
  """
16
18
  A fully factorized Gaussian approximation to a posterior distribution.
@@ -19,7 +21,7 @@ class MeanField(Variational, Generic[M]):
19
21
  - `var_params`: The variational parameters for the approximation.
20
22
  """
21
23
 
22
- var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
24
+ var_params: Dict[str, Float[Array, "..."]] # todo: just expand to attributes
23
25
 
24
26
  def __init__(self, model: M):
25
27
  """
@@ -9,7 +9,9 @@ from jaxtyping import Array, Key, Scalar
9
9
 
10
10
  from bayinx.core import Flow, Model, Variational
11
11
 
12
- M = TypeVar('M', bound=Model)
12
+ M = TypeVar("M", bound=Model)
13
+
14
+
13
15
  class NormalizingFlow(Variational, Generic[M]):
14
16
  """
15
17
  An ordered collection of diffeomorphisms that map a base distribution to a
bayinx/mhx/vi/standard.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
1
  import equinox as eqx
3
2
  import jax.numpy as jnp
4
3
  import jax.random as jr
@@ -17,6 +16,7 @@ class Standard(Variational[M]):
17
16
  # Attributes
18
17
  - `dim`: Dimension of the parameter space.
19
18
  """
19
+
20
20
  dim: int
21
21
 
22
22
  def __init__(self, model: M):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.11
3
+ Version: 0.3.12
4
4
  Summary: Bayesian Inference with JAX
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -1,14 +1,15 @@
1
1
  bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
2
2
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
3
+ bayinx/constraints/__init__.py,sha256=027WJxRLkybXZkmusfvR6iZayY2pDid7Tw6TTTeA6ko,64
4
4
  bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
5
- bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
5
+ bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
6
6
  bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
7
7
  bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
8
  bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
9
- bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
10
- bayinx/core/_variational.py,sha256=yU5_fsolFD8De0mJptjRBVq5lA7rQsDvT3qmVGbW-gI,5460
11
- bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
9
+ bayinx/core/_optimization.py,sha256=cXO07guDG5kd64kWq4_de-gLgxTT6vIOU3IOL3TMl6U,2583
10
+ bayinx/core/_parameter.py,sha256=UfLonbTwxzr1g76Cf3HzAh9u4UBtlcaLByYLXgq-aCQ,1082
11
+ bayinx/core/_variational.py,sha256=bQYiN8c4AGPt4hsNT68zN7J6o0fdsLGwVjbDyl62LnI,5639
12
+ bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
12
13
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
14
  bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
14
15
  bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
@@ -18,18 +19,19 @@ bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U
18
19
  bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
19
20
  bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
20
21
  bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
21
- bayinx/dists/censored/posnormal/r.py,sha256=Ypi6w_t53pAzRVzjcStx2RhozkAlCDLnQmgKykhpQQ4,3426
22
+ bayinx/dists/censored/posnormal/r.py,sha256=r8QfEThaaqsqqrr6PHWiY0EdfnjQcbKdTAHN2GmDVzI,3428
22
23
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
23
- bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
24
- bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
25
- bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
26
- bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
24
+ bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
26
+ bayinx/mhx/vi/meanfield.py,sha256=kx5WeD93-XO8NHxd4L4pZ8V19Y9B6j-yE3Y5OBXMcTk,3743
27
+ bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
28
+ bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
27
29
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
28
30
  bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
29
31
  bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
30
32
  bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
31
33
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
32
- bayinx-0.3.11.dist-info/METADATA,sha256=0D6AlI-tcUBQiLQXnCBPgj62oy-LzOQxV93Rw2cv5cA,3080
33
- bayinx-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
- bayinx-0.3.11.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
35
- bayinx-0.3.11.dist-info/RECORD,,
34
+ bayinx-0.3.12.dist-info/METADATA,sha256=RvwZfeFmarMW4SlQhurD5TLnAS0sge9lVNcHPUNjIIM,3080
35
+ bayinx-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
+ bayinx-0.3.12.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
37
+ bayinx-0.3.12.dist-info/RECORD,,