bayinx 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/constraints/__init__.py CHANGED
@@ -1,3 +1,3 @@
  from bayinx.constraints.lower import Lower
 
- __all__ = ['Lower']
+ __all__ = ["Lower"]
bayinx/core/__init__.py CHANGED
@@ -1,7 +1,16 @@
  from ._constraint import Constraint
  from ._flow import Flow
  from ._model import Model, constrain
+ from ._optimization import optimize_model
  from ._parameter import Parameter
  from ._variational import Variational
 
- __all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
+ __all__ = [
+     "Constraint",
+     "Flow",
+     "Model",
+     "constrain",
+     "optimize_model",
+     "Parameter",
+     "Variational",
+ ]
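For orientation, a minimal import sketch of the expanded `bayinx.core` namespace after this change (names taken directly from the new `__all__` above):

    from bayinx.core import (
        Constraint,
        Flow,
        Model,
        Parameter,
        Variational,
        constrain,
        optimize_model,
    )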
bayinx/core/_optimization.py ADDED
@@ -0,0 +1,94 @@
+ from typing import Any, Tuple
+
+ import equinox as eqx
+ import jax.lax as lax
+ import jax.numpy as jnp
+ import optax as opx
+ from jaxtyping import PyTree, Scalar
+ from optax import GradientTransformation, OptState, Schedule
+
+ from ._model import Model
+
+
+ @eqx.filter_jit
+ def optimize_model(
+     model: Model,
+     max_iters: int,
+     data: Any = None,
+     learning_rate: float = 1,
+     weight_decay: float = 0.0,
+     tolerance: float = 1e-4,
+ ) -> Model:
+     """
+     Optimize the dynamic parameters of the model.
+
+     # Parameters
+     - `max_iters`: Maximum number of iterations for the optimization loop.
+     - `data`: Data to evaluate the model with.
+     - `learning_rate`: Initial learning rate for optimization.
+     - `weight_decay`: Weight decay for the optimizer.
+     - `tolerance`: Relative tolerance of loss decrease for stopping early (not implemented in the loop).
+     """
+     # Get dynamic and static components of model
+     dyn, static = eqx.partition(model, model.filter_spec)
+
+     # Derive gradient for posterior
+     @eqx.filter_jit
+     @eqx.filter_grad
+     def eval_grad(dyn: Model):
+         # Reconstruct model
+         model: Model = eqx.combine(dyn, static)
+
+         # Evaluate posterior
+         return model.eval(data)
+
+     # Construct scheduler
+     schedule: Schedule = opx.warmup_cosine_decay_schedule(
+         init_value=1e-16,
+         peak_value=learning_rate,
+         warmup_steps=int(max_iters / 10),
+         decay_steps=max_iters - int(max_iters / 10),
+     )
+
+     optim: GradientTransformation = opx.chain(
+         opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+     )
+     opt_state: OptState = optim.init(dyn)
+
+     @eqx.filter_jit
+     def condition(state: Tuple[PyTree, OptState, Scalar]):
+         # Unpack iteration state
+         current_opt_dyn, opt_state, i = state
+
+         return i < max_iters
+
+     @eqx.filter_jit
+     def body(state: Tuple[PyTree, OptState, Scalar]):
+         # Unpack iteration state
+         dyn, opt_state, i = state
+
+         # Update iteration
+         i = i + 1
+
+         # Evaluate gradient of posterior
+         updates = eval_grad(dyn)
+
+         # Compute updates
+         updates, opt_state = optim.update(
+             updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
+         )
+
+         # Update model
+         dyn = eqx.apply_updates(dyn, updates)
+
+         return dyn, opt_state, i
+
+     # Run optimization loop
+     dyn = lax.while_loop(
+         cond_fun=condition,
+         body_fun=body,
+         init_val=(dyn, opt_state, jnp.array(0, jnp.uint32)),
+     )[0]
+
+     # Return optimized model
+     return eqx.combine(dyn, static)
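The new `optimize_model` follows a common optax pattern: a warmup-cosine learning-rate schedule, `nadamw` with weight decay, a sign flip via `opx.scale(-1.0)` so a posterior density to be maximized can be fed to a minimizing optimizer, and a `lax.while_loop` over the update step. A self-contained sketch of the same pattern on a toy log-density (illustrative only, not bayinx code; the quadratic target and all names below are assumptions):

    import jax
    import jax.lax as lax
    import jax.numpy as jnp
    import optax as opx

    def log_density(theta):
        # Toy target: isotropic Gaussian log-density (up to a constant), mode at 3.
        return -0.5 * jnp.sum((theta - 3.0) ** 2)

    max_iters = 1_000
    schedule = opx.warmup_cosine_decay_schedule(
        init_value=1e-16,
        peak_value=0.1,
        warmup_steps=max_iters // 10,
        decay_steps=max_iters - max_iters // 10,
    )
    # scale(-1.0) negates the gradient so ascent on log_density becomes descent for nadamw.
    optim = opx.chain(opx.scale(-1.0), opx.nadamw(schedule, weight_decay=0.0))

    theta = jnp.zeros(2)
    opt_state = optim.init(theta)

    def body(state):
        theta, opt_state, i = state
        grads = jax.grad(log_density)(theta)
        updates, opt_state = optim.update(grads, opt_state, theta)
        theta = opx.apply_updates(theta, updates)
        return theta, opt_state, i + 1

    theta, _, _ = lax.while_loop(
        lambda s: s[2] < max_iters, body, (theta, opt_state, jnp.array(0, jnp.uint32))
    )
    # theta ends up near the mode at (3.0, 3.0).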
bayinx/core/_parameter.py CHANGED
@@ -5,6 +5,8 @@ import jax.tree as jt
  from jaxtyping import PyTree
 
  T = TypeVar("T", bound=PyTree)
+
+
  class Parameter(eqx.Module, Generic[T]):
      """
      A container for a parameter of a `Model`.
bayinx/core/_variational.py CHANGED
@@ -13,7 +13,9 @@ from optax import GradientTransformation, OptState, Schedule
 
  from ._model import Model
 
- M = TypeVar('M', bound=Model)
+ M = TypeVar("M", bound=Model)
+
+
  class Variational(eqx.Module, Generic[M]):
      """
      An abstract base class used to define variational methods.
@@ -80,6 +82,7 @@ class Variational(eqx.Module, Generic[M]):
          # Evaluate posterior density
          return model.eval(data)
 
+     # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
      @eqx.filter_jit
      def fit(
          self,
@@ -107,7 +110,10 @@ class Variational(eqx.Module, Generic[M]):
 
          # Construct scheduler
          schedule: Schedule = opx.warmup_cosine_decay_schedule(
-             init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
+             init_value=1e-16,
+             peak_value=learning_rate,
+             warmup_steps=int(max_iters / 10),
+             decay_steps=max_iters - int(max_iters / 10),
          )
 
          # Initialize optimizer
@@ -163,7 +169,11 @@ class Variational(eqx.Module, Generic[M]):
 
      @eqx.filter_jit
      def posterior_predictive(
-         self, func: Callable[[M, Any], Array], n: int, data: Any = None, key: Key = jr.PRNGKey(0)
+         self,
+         func: Callable[[M, Any], Array],
+         n: int,
+         data: Any = None,
+         key: Key = jr.PRNGKey(0),
      ) -> Array:
          # Sample draws from the variational approximation
          draws: Array = self.sample(n, key)
bayinx/dists/__init__.py CHANGED
@@ -1,3 +1,3 @@
  from bayinx.dists import censored, gamma2, normal, posnormal
 
- __all__ = ['censored', "gamma2", "normal", "posnormal"]
+ __all__ = ["censored", "gamma2", "normal", "posnormal"]
bayinx/dists/censored/posnormal/r.py CHANGED
@@ -1,8 +1,9 @@
  import jax.numpy as jnp
  import jax.random as jr
+ from jax.scipy.special import ndtri
  from jaxtyping import Array, ArrayLike, Float, Key
 
- from bayinx.dists import posnormal
+ from bayinx.dists import normal, posnormal
 
 
  def prob(
@@ -78,12 +79,13 @@ def logprob(
 
      return evals
 
+
  def sample(
      n: int,
      mu: Float[ArrayLike, "..."],
      sigma: Float[ArrayLike, "..."],
      censor: Float[ArrayLike, "..."] = jnp.inf,
-     key: Key = jr.PRNGKey(0)
+     key: Key = jr.PRNGKey(0),
  ) -> Float[Array, "..."]:
      """
      Sample from a right-censored positive Normal distribution.
@@ -107,10 +109,8 @@
      # Derive shape
      shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape, censor.shape)
 
-     # Draw from positive normal
-     draws = jr.truncated_normal(key, 0.0, jnp.inf, shape) * sigma + mu
-
-     # Censor values
-     draws = jnp.where(censor <= draws, censor, draws)
+     # Construct draws
+     draws = jr.uniform(key, shape)
+     draws = mu + sigma * ndtri(normal.cdf(-mu/sigma, 0.0, 1.0) + draws * normal.cdf(mu/sigma, 0.0, 1.0))
 
      return draws
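For context on the new draw construction (a standard inverse-CDF identity, not code from the package): if X ~ Normal(mu, sigma^2) truncated to (0, inf) and U ~ Uniform(0, 1), then

    X = mu + sigma * Phi^{-1}( Phi(-mu/sigma) + U * (1 - Phi(-mu/sigma)) )

follows the truncated distribution. Since Phi(mu/sigma) = 1 - Phi(-mu/sigma), this is exactly the added `ndtri(normal.cdf(-mu/sigma, 0.0, 1.0) + draws * normal.cdf(mu/sigma, 0.0, 1.0))` expression, with `ndtri` being the standard-normal quantile function Phi^{-1}.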
bayinx/dists/uniform.py CHANGED
@@ -1,6 +1,7 @@
  import jax.lax as _lax
  import jax.numpy as jnp
- from jaxtyping import Array, ArrayLike, Float
+ import jax.random as jr
+ from jaxtyping import Array, ArrayLike, Float, Key
 
 
  def prob(
@@ -17,8 +18,10 @@ def prob(
      # Returns
      The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """
+     # Cast to Array
+     x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
 
-     return 1.0 / (ub - lb)  # pyright: ignore
+     return 1.0 / (ub - lb)
 
 
  def logprob(
@@ -35,8 +38,10 @@ def logprob(
      # Returns
      The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """
+     # Cast to Array
+     x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
 
-     return _lax.log(1.0) - _lax.log(ub - lb)  # pyright: ignore
+     return _lax.log(1.0) - _lax.log(ub - lb)
 
 
  def uprob(
@@ -53,8 +58,10 @@ def uprob(
      # Returns
      The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """
+     # Cast to Array
+     x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
 
-     return jnp.ones(jnp.broadcast_arrays(x, lb, ub))
+     return jnp.ones(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
 
 
  def ulogprob(
@@ -71,5 +78,32 @@
      # Returns
      The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """
+     # Cast to Array
+     x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)
 
-     return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
+     return jnp.zeros(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
+
+ def sample(
+     n: int, lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."], key: Key = jr.PRNGKey(0),
+ ) -> Float[Array, "..."]:
+     """
+     Sample from a Uniform distribution.
+
+     # Parameters
+     - `n`: Number of draws to sample per-parameter.
+     - `lb`: The lower bound parameter(s).
+     - `ub`: The upper bound parameter(s).
+
+     # Returns
+     Draws from a Uniform distribution. The output will have the shape of (n,) + the broadcasted shapes of `lb` and `ub`.
+     """
+     # Cast to Array
+     lb, ub = jnp.asarray(lb), jnp.asarray(ub)
+
+     # Derive shape
+     shape = (n,) + jnp.broadcast_shapes(lb.shape, ub.shape)
+
+     # Construct draws
+     draws = jr.uniform(key, shape, minval = lb, maxval = ub)
+
+     return draws
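A minimal usage sketch of the new `uniform.sample` (assuming the import path shown in this diff; the bounds below are illustrative):

    import jax.numpy as jnp
    import jax.random as jr
    from bayinx.dists import uniform

    # 100 draws per parameter, broadcasting a scalar lower bound
    # against a vector of upper bounds; result has shape (100, 2).
    draws = uniform.sample(100, lb=0.0, ub=jnp.array([1.0, 2.0]), key=jr.PRNGKey(1))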
File without changes
bayinx/mhx/vi/__init__.py CHANGED
@@ -2,4 +2,4 @@ from bayinx.mhx.vi.meanfield import MeanField
  from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
  from bayinx.mhx.vi.standard import Standard
 
- __all__ = ['MeanField', 'NormalizingFlow', 'Standard']
+ __all__ = ["MeanField", "NormalizingFlow", "Standard"]
bayinx/mhx/vi/meanfield.py CHANGED
@@ -10,7 +10,9 @@ from jaxtyping import Array, Float, Key, Scalar
  from bayinx.core import Model, Variational
  from bayinx.dists import normal
 
- M = TypeVar('M', bound=Model)
+ M = TypeVar("M", bound=Model)
+
+
  class MeanField(Variational, Generic[M]):
      """
      A fully factorized Gaussian approximation to a posterior distribution.
@@ -19,7 +21,7 @@ class MeanField(Variational, Generic[M]):
      - `var_params`: The variational parameters for the approximation.
      """
 
-     var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
+     var_params: Dict[str, Float[Array, "..."]]  # todo: just expand to attributes
 
      def __init__(self, model: M):
          """
bayinx/mhx/vi/normalizing_flow.py CHANGED
@@ -9,7 +9,9 @@ from jaxtyping import Array, Key, Scalar
 
  from bayinx.core import Flow, Model, Variational
 
- M = TypeVar('M', bound=Model)
+ M = TypeVar("M", bound=Model)
+
+
  class NormalizingFlow(Variational, Generic[M]):
      """
      An ordered collection of diffeomorphisms that map a base distribution to a
bayinx/mhx/vi/standard.py CHANGED
@@ -1,4 +1,3 @@
-
  import equinox as eqx
  import jax.numpy as jnp
  import jax.random as jr
@@ -17,6 +16,7 @@ class Standard(Variational[M]):
      # Attributes
      - `dim`: Dimension of the parameter space.
      """
+
      dim: int
 
      def __init__(self, model: M):
bayinx-0.3.11.dist-info/METADATA → bayinx-0.3.13.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayinx
- Version: 0.3.11
+ Version: 0.3.13
  Summary: Bayesian Inference with JAX
  License-File: LICENSE
  Requires-Python: >=3.12
bayinx-0.3.11.dist-info/RECORD → bayinx-0.3.13.dist-info/RECORD RENAMED
@@ -1,35 +1,37 @@
  bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
+ bayinx/constraints/__init__.py,sha256=027WJxRLkybXZkmusfvR6iZayY2pDid7Tw6TTTeA6ko,64
  bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
- bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
+ bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
  bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
  bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
  bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
- bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
- bayinx/core/_variational.py,sha256=yU5_fsolFD8De0mJptjRBVq5lA7rQsDvT3qmVGbW-gI,5460
- bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
+ bayinx/core/_optimization.py,sha256=cXO07guDG5kd64kWq4_de-gLgxTT6vIOU3IOL3TMl6U,2583
+ bayinx/core/_parameter.py,sha256=UfLonbTwxzr1g76Cf3HzAh9u4UBtlcaLByYLXgq-aCQ,1082
+ bayinx/core/_variational.py,sha256=bQYiN8c4AGPt4hsNT68zN7J6o0fdsLGwVjbDyl62LnI,5639
+ bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
  bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
  bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
  bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
- bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
+ bayinx/dists/uniform.py,sha256=2ZQxEfAX5TFgSPuQ8joFDuFbd_NfmQ1GvmGGjusqvNQ,3461
  bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
  bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
  bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
  bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
- bayinx/dists/censored/posnormal/r.py,sha256=Ypi6w_t53pAzRVzjcStx2RhozkAlCDLnQmgKykhpQQ4,3426
+ bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-t9kidH4,3456
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
- bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
- bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
- bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
+ bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
+ bayinx/mhx/vi/meanfield.py,sha256=kx5WeD93-XO8NHxd4L4pZ8V19Y9B6j-yE3Y5OBXMcTk,3743
+ bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
+ bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
  bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
  bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
  bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
- bayinx-0.3.11.dist-info/METADATA,sha256=0D6AlI-tcUBQiLQXnCBPgj62oy-LzOQxV93Rw2cv5cA,3080
- bayinx-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- bayinx-0.3.11.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
- bayinx-0.3.11.dist-info/RECORD,,
+ bayinx-0.3.13.dist-info/METADATA,sha256=1mcHMTXrzMGPcibMy_vHdBJMNPtT0VUF83eVBS_MJlg,3080
+ bayinx-0.3.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ bayinx-0.3.13.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
+ bayinx-0.3.13.dist-info/RECORD,,