bayinx 0.3.11__tar.gz → 0.3.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {bayinx-0.3.11 → bayinx-0.3.13}/PKG-INFO +1 -1
  2. {bayinx-0.3.11 → bayinx-0.3.13}/pyproject.toml +2 -2
  3. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/constraints/__init__.py +1 -1
  4. bayinx-0.3.13/src/bayinx/core/__init__.py +16 -0
  5. bayinx-0.3.13/src/bayinx/core/_optimization.py +94 -0
  6. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_parameter.py +2 -0
  7. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_variational.py +13 -3
  8. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/__init__.py +1 -1
  9. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/censored/posnormal/r.py +7 -7
  10. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/uniform.py +39 -5
  11. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/__init__.py +1 -1
  12. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/meanfield.py +4 -2
  13. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/normalizing_flow.py +3 -1
  14. bayinx-0.3.13/tests/__init__.py +0 -0
  15. bayinx-0.3.11/src/bayinx/core/__init__.py +0 -7
  16. {bayinx-0.3.11 → bayinx-0.3.13}/.github/workflows/release_and_publish.yml +0 -0
  17. {bayinx-0.3.11 → bayinx-0.3.13}/.gitignore +0 -0
  18. {bayinx-0.3.11 → bayinx-0.3.13}/LICENSE +0 -0
  19. {bayinx-0.3.11 → bayinx-0.3.13}/README.md +0 -0
  20. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/__init__.py +0 -0
  21. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/constraints/lower.py +0 -0
  22. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_constraint.py +0 -0
  23. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_flow.py +0 -0
  24. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_model.py +0 -0
  25. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/bernoulli.py +0 -0
  26. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/censored/__init__.py +0 -0
  27. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
  28. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  29. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
  30. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/gamma2.py +0 -0
  31. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/normal.py +0 -0
  32. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/posnormal.py +0 -0
  33. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/__init__.py +0 -0
  34. {bayinx-0.3.11/tests → bayinx-0.3.13/src/bayinx/mhx/opt}/__init__.py +0 -0
  35. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  36. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  37. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  38. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  39. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  40. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/standard.py +1 -1
  41. {bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/py.typed +0 -0
  42. {bayinx-0.3.11 → bayinx-0.3.13}/tests/test_predictive.py +0 -0
  43. {bayinx-0.3.11 → bayinx-0.3.13}/tests/test_variational.py +0 -0
  44. {bayinx-0.3.11 → bayinx-0.3.13}/uv.lock +0 -0
{bayinx-0.3.11 → bayinx-0.3.13}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayinx
-Version: 0.3.11
+Version: 0.3.13
 Summary: Bayesian Inference with JAX
 License-File: LICENSE
 Requires-Python: >=3.12
{bayinx-0.3.11 → bayinx-0.3.13}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.3.11"
+version = "0.3.13"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"

 [tool.bumpversion]
-current_version = "0.3.11"
+current_version = "0.3.13"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/constraints/__init__.py
@@ -1,3 +1,3 @@
 from bayinx.constraints.lower import Lower

-__all__ = ['Lower']
+__all__ = ["Lower"]
bayinx-0.3.13/src/bayinx/core/__init__.py (new file)
@@ -0,0 +1,16 @@
+from ._constraint import Constraint
+from ._flow import Flow
+from ._model import Model, constrain
+from ._optimization import optimize_model
+from ._parameter import Parameter
+from ._variational import Variational
+
+__all__ = [
+    "Constraint",
+    "Flow",
+    "Model",
+    "constrain",
+    "optimize_model",
+    "Parameter",
+    "Variational",
+]
bayinx-0.3.13/src/bayinx/core/_optimization.py (new file)
@@ -0,0 +1,94 @@
+from typing import Any, Tuple
+
+import equinox as eqx
+import jax.lax as lax
+import jax.numpy as jnp
+import optax as opx
+from jaxtyping import PyTree, Scalar
+from optax import GradientTransformation, OptState, Schedule
+
+from ._model import Model
+
+
+@eqx.filter_jit
+def optimize_model(
+    model: Model,
+    max_iters: int,
+    data: Any = None,
+    learning_rate: float = 1,
+    weight_decay: float = 0.0,
+    tolerance: float = 1e-4,
+) -> Model:
+    """
+    Optimize the dynamic parameters of the model.
+
+    # Parameters
+    - `max_iters`: Maximum number of iterations for the optimization loop.
+    - `data`: Data to evaluate the model with.
+    - `learning_rate`: Initial learning rate for optimization.
+    - `weight_decay`: Weight decay for the optimizer.
+    - `tolerance`: Relative tolerance of loss decrease for stopping early (not implemented in the loop).
+    """
+    # Get dynamic and static componts of model
+    dyn, static = eqx.partition(model, model.filter_spec)
+
+    # Derive gradient for posterior
+    @eqx.filter_jit
+    @eqx.filter_grad
+    def eval_grad(dyn: Model):
+        # Reconstruct model
+        model: Model = eqx.combine(dyn, static)
+
+        # Evaluate posterior
+        return model.eval(data)
+
+    # Construct scheduler
+    schedule: Schedule = opx.warmup_cosine_decay_schedule(
+        init_value=1e-16,
+        peak_value=learning_rate,
+        warmup_steps=int(max_iters / 10),
+        decay_steps=max_iters - int(max_iters / 10),
+    )
+
+    optim: GradientTransformation = opx.chain(
+        opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+    )
+    opt_state: OptState = optim.init(dyn)
+
+    @eqx.filter_jit
+    def condition(state: Tuple[PyTree, OptState, Scalar]):
+        # Unpack iteration state
+        current_opt_dyn, opt_state, i = state
+
+        return i < max_iters
+
+    @eqx.filter_jit
+    def body(state: Tuple[PyTree, OptState, Scalar]):
+        # Unpack iteration state
+        dyn, opt_state, i = state
+
+        # Update iteration
+        i = i + 1
+
+        # Evaluate gradient of posterior
+        updates = eval_grad(dyn)
+
+        # Compute updates
+        updates, opt_state = optim.update(
+            updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
+        )
+
+        # Update model
+        dyn = eqx.apply_updates(dyn, updates)
+
+        return dyn, opt_state, i
+
+    # Run optimization loop
+    dyn = lax.while_loop(
+        cond_fun=condition,
+        body_fun=body,
+        init_val=(dyn, opt_state, jnp.array(0, jnp.uint32)),
+    )[0]
+
+    # Return optimized model
+    return eqx.combine(dyn, static)
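The new `optimize_model` helper runs a fixed number of NAdamW steps inside a `lax.while_loop`; gradients are negated via `opx.scale(-1.0)`, so the loop performs gradient ascent on `model.eval` over the model's dynamic parameters and returns the recombined model. A minimal usage sketch; `MyModel` and `my_data` are hypothetical stand-ins for a user-defined `Model` subclass and whatever its `eval` expects:

```python
from bayinx.core import Model, optimize_model

# Hypothetical user model: any `Model` subclass whose `eval(data)` returns the
# scalar objective to maximize (e.g. a joint log density) should work here.
model: Model = MyModel()

# Run 1_000 optimizer steps with the built-in warmup-cosine-decay schedule and
# get back a copy of the model with optimized dynamic parameters.
fitted = optimize_model(
    model,
    max_iters=1_000,
    data=my_data,        # forwarded to `model.eval`; may be None
    learning_rate=1e-1,  # peak value of the learning-rate schedule
    weight_decay=0.0,
)
```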
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_parameter.py
@@ -5,6 +5,8 @@ import jax.tree as jt
 from jaxtyping import PyTree

 T = TypeVar("T", bound=PyTree)
+
+
 class Parameter(eqx.Module, Generic[T]):
     """
     A container for a parameter of a `Model`.
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/core/_variational.py
@@ -13,7 +13,9 @@ from optax import GradientTransformation, OptState, Schedule

 from ._model import Model

-M = TypeVar('M', bound=Model)
+M = TypeVar("M", bound=Model)
+
+
 class Variational(eqx.Module, Generic[M]):
     """
     An abstract base class used to define variational methods.
@@ -80,6 +82,7 @@ class Variational(eqx.Module, Generic[M]):
         # Evaluate posterior density
         return model.eval(data)

+    # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
     @eqx.filter_jit
     def fit(
         self,
@@ -107,7 +110,10 @@

         # Construct scheduler
         schedule: Schedule = opx.warmup_cosine_decay_schedule(
-            init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
+            init_value=1e-16,
+            peak_value=learning_rate,
+            warmup_steps=int(max_iters / 10),
+            decay_steps=max_iters - int(max_iters / 10),
         )

         # Initialize optimizer
@@ -163,7 +169,11 @@

     @eqx.filter_jit
     def posterior_predictive(
-        self, func: Callable[[M, Any], Array], n: int, data: Any = None, key: Key = jr.PRNGKey(0)
+        self,
+        func: Callable[[M, Any], Array],
+        n: int,
+        data: Any = None,
+        key: Key = jr.PRNGKey(0),
     ) -> Array:
         # Sample draws from the variational approximation
         draws: Array = self.sample(n, key)
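The reformatted `posterior_predictive` takes a callable of the model and data plus a number of draws. A hedged sketch of how it is called; `approx`, `predict_fn`, `my_data`, and the `predict` method are illustrative names, assuming `approx` is a fitted `Variational` instance:

```python
import jax.random as jr

# `predict_fn` maps a reconstituted model and the supplied data to an Array of
# predictions; it is an illustrative callable, not part of bayinx.
def predict_fn(model, data):
    return model.predict(data)  # hypothetical method on the user's Model

# Draw 500 posterior-predictive samples from a fitted variational approximation.
preds = approx.posterior_predictive(func=predict_fn, n=500, data=my_data, key=jr.PRNGKey(0))
```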
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/__init__.py
@@ -1,3 +1,3 @@
 from bayinx.dists import censored, gamma2, normal, posnormal

-__all__ = ['censored', "gamma2", "normal", "posnormal"]
+__all__ = ["censored", "gamma2", "normal", "posnormal"]
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/censored/posnormal/r.py
@@ -1,8 +1,9 @@
 import jax.numpy as jnp
 import jax.random as jr
+from jax.scipy.special import ndtri
 from jaxtyping import Array, ArrayLike, Float, Key

-from bayinx.dists import posnormal
+from bayinx.dists import normal, posnormal


 def prob(
@@ -78,12 +79,13 @@

     return evals

+
 def sample(
     n: int,
     mu: Float[ArrayLike, "..."],
     sigma: Float[ArrayLike, "..."],
     censor: Float[ArrayLike, "..."] = jnp.inf,
-    key: Key = jr.PRNGKey(0)
+    key: Key = jr.PRNGKey(0),
 ) -> Float[Array, "..."]:
     """
     Sample from a right-censored positive Normal distribution.
@@ -107,10 +109,8 @@
     # Derive shape
     shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape, censor.shape)

-    # Draw from positive normal
-    draws = jr.truncated_normal(key, 0.0, jnp.inf, shape) * sigma + mu
-
-    # Censor values
-    draws = jnp.where(censor <= draws, censor, draws)
+    # Construct draws
+    draws = jr.uniform(key, shape)
+    draws = mu + sigma * ndtri(normal.cdf(-mu/sigma, 0.0, 1.0) + draws * normal.cdf(mu/sigma, 0.0, 1.0))

     return draws
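The sampler now uses an inverse-CDF draw for the positive (zero-truncated) normal: with Phi the standard normal CDF and u ~ Uniform(0, 1), x = mu + sigma * Phi^-1(Phi(-mu/sigma) + u * Phi(mu/sigma)) lands on the positive half-line, since P(X > 0) = Phi(mu/sigma) for X ~ Normal(mu, sigma). A self-contained sketch of the same transform using only jax.scipy; `sample_positive_normal` is an illustrative name, not a bayinx function:

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import ndtri
from jax.scipy.stats import norm

def sample_positive_normal(key, n, mu, sigma):
    # Inverse-CDF sampling of Normal(mu, sigma) truncated to (0, inf):
    # map u ~ Uniform(0, 1) through the truncated quantile function.
    u = jr.uniform(key, (n,) + jnp.shape(mu))
    lower = norm.cdf(0.0, loc=mu, scale=sigma)  # Phi(-mu / sigma)
    return mu + sigma * ndtri(lower + u * (1.0 - lower))  # 1 - lower = Phi(mu / sigma)

draws = sample_positive_normal(jr.PRNGKey(0), 1_000, mu=1.0, sigma=2.0)
print(draws.shape, float(draws.min()))  # (1000,), minimum is (numerically) at or above 0
```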
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/dists/uniform.py
@@ -1,6 +1,7 @@
 import jax.lax as _lax
 import jax.numpy as jnp
-from jaxtyping import Array, ArrayLike, Float
+import jax.random as jr
+from jaxtyping import Array, ArrayLike, Float, Key


 def prob(
@@ -17,8 +18,10 @@
     # Returns
     The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
     """
+    # Cast to Array
+    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

-    return 1.0 / (ub - lb) # pyright: ignore
+    return 1.0 / (ub - lb)


 def logprob(
@@ -35,8 +38,10 @@
     # Returns
     The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
     """
+    # Cast to Array
+    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

-    return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
+    return _lax.log(1.0) - _lax.log(ub - lb)


 def uprob(
@@ -53,8 +58,10 @@
     # Returns
     The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
     """
+    # Cast to Array
+    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

-    return jnp.ones(jnp.broadcast_arrays(x, lb, ub))
+    return jnp.ones(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))


 def ulogprob(
@@ -71,5 +78,32 @@
     # Returns
     The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
     """
+    # Cast to Array
+    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

-    return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
+    return jnp.zeros(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
+
+def sample(
+    n: int, lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."], key: Key = jr.PRNGKey(0),
+) -> Float[Array, "..."]:
+    """
+    Sample from a Uniform distribution.
+
+    # Parameters
+    - `n`: Number of draws to sample per-parameter.
+    - `lb`: The lower bound parameter(s).
+    - `ub`: The upper bound parameter(s).
+
+    # Returns
+    Draws from a Uniform distribution. The output will have the shape of (n,) + the broadcasted shapes of `lb` and `ub`.
+    """
+    # Cast to Array
+    lb, ub = jnp.asarray(lb), jnp.asarray(ub)
+
+    # Derive shape
+    shape = (n,) + jnp.broadcast_shapes(lb.shape, ub.shape)
+
+    # Construct draws
+    draws = jr.uniform(key, shape, minval = lb, maxval = ub)
+
+    return draws
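A short usage sketch of the new `uniform.sample` helper; the bounds and key below are illustrative:

```python
import jax.numpy as jnp
import jax.random as jr

from bayinx.dists import uniform

# Five draws each from Uniform(0, 1) and Uniform(0, 10); the result has shape
# (n,) + broadcast_shapes(lb.shape, ub.shape).
draws = uniform.sample(
    n=5,
    lb=jnp.array([0.0, 0.0]),
    ub=jnp.array([1.0, 10.0]),
    key=jr.PRNGKey(1),
)
print(draws.shape)  # (5, 2)
```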
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/__init__.py
@@ -2,4 +2,4 @@ from bayinx.mhx.vi.meanfield import MeanField
 from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
 from bayinx.mhx.vi.standard import Standard

-__all__ = ['MeanField', 'NormalizingFlow', 'Standard']
+__all__ = ["MeanField", "NormalizingFlow", "Standard"]
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/meanfield.py
@@ -10,7 +10,9 @@ from jaxtyping import Array, Float, Key, Scalar
 from bayinx.core import Model, Variational
 from bayinx.dists import normal

-M = TypeVar('M', bound=Model)
+M = TypeVar("M", bound=Model)
+
+
 class MeanField(Variational, Generic[M]):
     """
     A fully factorized Gaussian approximation to a posterior distribution.
@@ -19,7 +21,7 @@ class MeanField(Variational, Generic[M]):
     - `var_params`: The variational parameters for the approximation.
     """

-    var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
+    var_params: Dict[str, Float[Array, "..."]]  # todo: just expand to attributes

     def __init__(self, model: M):
         """
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/normalizing_flow.py
@@ -9,7 +9,9 @@ from jaxtyping import Array, Key, Scalar

 from bayinx.core import Flow, Model, Variational

-M = TypeVar('M', bound=Model)
+M = TypeVar("M", bound=Model)
+
+
 class NormalizingFlow(Variational, Generic[M]):
     """
     An ordered collection of diffeomorphisms that map a base distribution to a
bayinx-0.3.11/src/bayinx/core/__init__.py (removed file)
@@ -1,7 +0,0 @@
-from ._constraint import Constraint
-from ._flow import Flow
-from ._model import Model, constrain
-from ._parameter import Parameter
-from ._variational import Variational
-
-__all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
{bayinx-0.3.11 → bayinx-0.3.13}/src/bayinx/mhx/vi/standard.py
@@ -1,4 +1,3 @@
-
 import equinox as eqx
 import jax.numpy as jnp
 import jax.random as jr
@@ -17,6 +16,7 @@ class Standard(Variational[M]):
     # Attributes
     - `dim`: Dimension of the parameter space.
     """
+
     dim: int

     def __init__(self, model: M):