bayinx 0.3.10__tar.gz → 0.3.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {bayinx-0.3.10 → bayinx-0.3.12}/PKG-INFO +1 -1
  2. {bayinx-0.3.10 → bayinx-0.3.12}/pyproject.toml +2 -2
  3. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/constraints/__init__.py +1 -1
  4. bayinx-0.3.12/src/bayinx/core/__init__.py +16 -0
  5. bayinx-0.3.12/src/bayinx/core/_optimization.py +94 -0
  6. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_parameter.py +2 -0
  7. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_variational.py +15 -5
  8. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/__init__.py +1 -1
  9. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/censored/posnormal/r.py +2 -1
  10. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/__init__.py +1 -1
  11. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/meanfield.py +4 -2
  12. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/normalizing_flow.py +3 -1
  13. bayinx-0.3.12/tests/__init__.py +0 -0
  14. bayinx-0.3.10/src/bayinx/core/__init__.py +0 -7
  15. {bayinx-0.3.10 → bayinx-0.3.12}/.github/workflows/release_and_publish.yml +0 -0
  16. {bayinx-0.3.10 → bayinx-0.3.12}/.gitignore +0 -0
  17. {bayinx-0.3.10 → bayinx-0.3.12}/LICENSE +0 -0
  18. {bayinx-0.3.10 → bayinx-0.3.12}/README.md +0 -0
  19. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/__init__.py +0 -0
  20. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/constraints/lower.py +0 -0
  21. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_constraint.py +0 -0
  22. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_flow.py +0 -0
  23. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_model.py +0 -0
  24. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/bernoulli.py +0 -0
  25. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/censored/__init__.py +0 -0
  26. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
  27. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  28. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
  29. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/gamma2.py +0 -0
  30. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/normal.py +0 -0
  31. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/posnormal.py +0 -0
  32. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/uniform.py +0 -0
  33. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/__init__.py +0 -0
  34. {bayinx-0.3.10/tests → bayinx-0.3.12/src/bayinx/mhx/opt}/__init__.py +0 -0
  35. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  36. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  37. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  38. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  39. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  40. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/standard.py +1 -1
  41. {bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/py.typed +0 -0
  42. {bayinx-0.3.10 → bayinx-0.3.12}/tests/test_predictive.py +0 -0
  43. {bayinx-0.3.10 → bayinx-0.3.12}/tests/test_variational.py +0 -0
  44. {bayinx-0.3.10 → bayinx-0.3.12}/uv.lock +0 -0
{bayinx-0.3.10 → bayinx-0.3.12}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayinx
- Version: 0.3.10
+ Version: 0.3.12
  Summary: Bayesian Inference with JAX
  License-File: LICENSE
  Requires-Python: >=3.12

{bayinx-0.3.10 → bayinx-0.3.12}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "bayinx"
- version = "0.3.10"
+ version = "0.3.12"
  description = "Bayesian Inference with JAX"
  readme = "README.md"
  requires-python = ">=3.12"
@@ -19,7 +19,7 @@ build-backend = "hatchling.build"
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"

  [tool.bumpversion]
- current_version = "0.3.10"
+ current_version = "0.3.12"
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
  serialize = ["{major}.{minor}.{patch}"]
  search = "{current_version}"

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/constraints/__init__.py
@@ -1,3 +1,3 @@
  from bayinx.constraints.lower import Lower

- __all__ = ['Lower']
+ __all__ = ["Lower"]

bayinx-0.3.12/src/bayinx/core/__init__.py (new file)
@@ -0,0 +1,16 @@
+ from ._constraint import Constraint
+ from ._flow import Flow
+ from ._model import Model, constrain
+ from ._optimization import optimize_model
+ from ._parameter import Parameter
+ from ._variational import Variational
+
+ __all__ = [
+     "Constraint",
+     "Flow",
+     "Model",
+     "constrain",
+     "optimize_model",
+     "Parameter",
+     "Variational",
+ ]
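
The rewritten `bayinx.core` package re-exports the new `optimize_model` helper alongside the existing core classes, so downstream code can avoid the private submodules. A minimal import sketch using only the names re-exported above:

from bayinx.core import Model, Parameter, Variational, optimize_model
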

bayinx-0.3.12/src/bayinx/core/_optimization.py (new file)
@@ -0,0 +1,94 @@
+ from typing import Any, Tuple
+
+ import equinox as eqx
+ import jax.lax as lax
+ import jax.numpy as jnp
+ import optax as opx
+ from jaxtyping import PyTree, Scalar
+ from optax import GradientTransformation, OptState, Schedule
+
+ from ._model import Model
+
+
+ @eqx.filter_jit
+ def optimize_model(
+     model: Model,
+     max_iters: int,
+     data: Any = None,
+     learning_rate: float = 1,
+     weight_decay: float = 0.0,
+     tolerance: float = 1e-4,
+ ) -> Model:
+     """
+     Optimize the dynamic parameters of the model.
+
+     # Parameters
+     - `max_iters`: Maximum number of iterations for the optimization loop.
+     - `data`: Data to evaluate the model with.
+     - `learning_rate`: Initial learning rate for optimization.
+     - `weight_decay`: Weight decay for the optimizer.
+     - `tolerance`: Relative tolerance of loss decrease for stopping early (not implemented in the loop).
+     """
+     # Get dynamic and static components of model
+     dyn, static = eqx.partition(model, model.filter_spec)
+
+     # Derive gradient for posterior
+     @eqx.filter_jit
+     @eqx.filter_grad
+     def eval_grad(dyn: Model):
+         # Reconstruct model
+         model: Model = eqx.combine(dyn, static)
+
+         # Evaluate posterior
+         return model.eval(data)
+
+     # Construct scheduler
+     schedule: Schedule = opx.warmup_cosine_decay_schedule(
+         init_value=1e-16,
+         peak_value=learning_rate,
+         warmup_steps=int(max_iters / 10),
+         decay_steps=max_iters - int(max_iters / 10),
+     )
+
+     optim: GradientTransformation = opx.chain(
+         opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
+     )
+     opt_state: OptState = optim.init(dyn)
+
+     @eqx.filter_jit
+     def condition(state: Tuple[PyTree, OptState, Scalar]):
+         # Unpack iteration state
+         current_opt_dyn, opt_state, i = state
+
+         return i < max_iters
+
+     @eqx.filter_jit
+     def body(state: Tuple[PyTree, OptState, Scalar]):
+         # Unpack iteration state
+         dyn, opt_state, i = state
+
+         # Update iteration
+         i = i + 1
+
+         # Evaluate gradient of posterior
+         updates = eval_grad(dyn)
+
+         # Compute updates
+         updates, opt_state = optim.update(
+             updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
+         )
+
+         # Update model
+         dyn = eqx.apply_updates(dyn, updates)
+
+         return dyn, opt_state, i
+
+     # Run optimization loop
+     dyn = lax.while_loop(
+         cond_fun=condition,
+         body_fun=body,
+         init_val=(dyn, opt_state, jnp.array(0, jnp.uint32)),
+     )[0]
+
+     # Return optimized model
+     return eqx.combine(dyn, static)
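
In short, `optimize_model` partitions the model into dynamic and static parts with `eqx.partition`, maximizes `model.eval(data)` by chaining `opx.scale(-1.0)` with `nadamw` under a warmup-cosine-decay schedule, and always runs the full `max_iters` steps (the `tolerance` argument is accepted but not yet used for early stopping). A hedged usage sketch; `MyModel` and `train_data` are hypothetical placeholders, only the call signature comes from the file above:

from bayinx.core import optimize_model

model = MyModel()           # hypothetical user-defined Model subclass
fitted = optimize_model(
    model,
    max_iters=5_000,        # fixed number of optimizer steps
    data=train_data,        # forwarded to model.eval(data)
    learning_rate=1e-1,     # peak of the warmup-cosine-decay schedule
    weight_decay=0.0,
)
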

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_parameter.py
@@ -5,6 +5,8 @@ import jax.tree as jt
  from jaxtyping import PyTree

  T = TypeVar("T", bound=PyTree)
+
+
  class Parameter(eqx.Module, Generic[T]):
      """
      A container for a parameter of a `Model`.

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/core/_variational.py
@@ -13,7 +13,9 @@ from optax import GradientTransformation, OptState, Schedule

  from ._model import Model

- M = TypeVar('M', bound=Model)
+ M = TypeVar("M", bound=Model)
+
+
  class Variational(eqx.Module, Generic[M]):
      """
      An abstract base class used to define variational methods.
@@ -80,6 +82,7 @@ class Variational(eqx.Module, Generic[M]):
              # Evaluate posterior density
              return model.eval(data)

+     # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
      @eqx.filter_jit
      def fit(
          self,
@@ -107,7 +110,10 @@

          # Construct scheduler
          schedule: Schedule = opx.warmup_cosine_decay_schedule(
-             init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
+             init_value=1e-16,
+             peak_value=learning_rate,
+             warmup_steps=int(max_iters / 10),
+             decay_steps=max_iters - int(max_iters / 10),
          )

          # Initialize optimizer
@@ -163,7 +169,11 @@

      @eqx.filter_jit
      def posterior_predictive(
-         self, func: Callable[[M, Any], Array], n: int, data: Any = None, key: Key = jr.PRNGKey(0)
+         self,
+         func: Callable[[M, Any], Array],
+         n: int,
+         data: Any = None,
+         key: Key = jr.PRNGKey(0),
      ) -> Array:
          # Sample draws from the variational approximation
          draws: Array = self.sample(n, key)
@@ -171,11 +181,11 @@
          # Evaluate posterior predictive
          @jax.jit
          @jax.vmap
-         def evaluate(draw: Array, data: Any = None):
+         def evaluate(draw: Array):
              # Reconstruct model
              model: M = self._unflatten(draw)

              # Evaluate
              return func(model, data)

-         return evaluate(draws, data)
+         return evaluate(draws)
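
This `posterior_predictive` change is behavioural, not cosmetic: `jax.vmap` maps over every positional argument by default, so the old `evaluate(draw, data)` would also have sliced `data` along its leading axis (or failed on an axis-size mismatch). Closing over `data` and mapping only over `draw` gives each draw the full dataset. A standalone sketch of the same pattern with toy arrays (not bayinx code):

import jax
import jax.numpy as jnp

draws = jnp.ones((8, 3))    # 8 posterior draws of a 3-dimensional parameter
data = jnp.ones((100, 3))   # one dataset shared by every draw

@jax.vmap
def evaluate(draw):
    # `data` is captured by closure, so each draw sees all 100 rows;
    # vmapping over `data` too would fail here (axis sizes 8 vs 100).
    return jnp.sum(draw * data)

out = evaluate(draws)  # shape (8,)
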
{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/__init__.py
@@ -1,3 +1,3 @@
  from bayinx.dists import censored, gamma2, normal, posnormal

- __all__ = ['censored', "gamma2", "normal", "posnormal"]
+ __all__ = ["censored", "gamma2", "normal", "posnormal"]

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/dists/censored/posnormal/r.py
@@ -78,12 +78,13 @@ def logprob(

      return evals

+
  def sample(
      n: int,
      mu: Float[ArrayLike, "..."],
      sigma: Float[ArrayLike, "..."],
      censor: Float[ArrayLike, "..."] = jnp.inf,
-     key: Key = jr.PRNGKey(0)
+     key: Key = jr.PRNGKey(0),
  ) -> Float[Array, "..."]:
      """
      Sample from a right-censored positive Normal distribution.

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/__init__.py
@@ -2,4 +2,4 @@ from bayinx.mhx.vi.meanfield import MeanField
  from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
  from bayinx.mhx.vi.standard import Standard

- __all__ = ['MeanField', 'NormalizingFlow', 'Standard']
+ __all__ = ["MeanField", "NormalizingFlow", "Standard"]

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/meanfield.py
@@ -10,7 +10,9 @@ from jaxtyping import Array, Float, Key, Scalar
  from bayinx.core import Model, Variational
  from bayinx.dists import normal

- M = TypeVar('M', bound=Model)
+ M = TypeVar("M", bound=Model)
+
+
  class MeanField(Variational, Generic[M]):
      """
      A fully factorized Gaussian approximation to a posterior distribution.
@@ -19,7 +21,7 @@ class MeanField(Variational, Generic[M]):
      - `var_params`: The variational parameters for the approximation.
      """

-     var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
+     var_params: Dict[str, Float[Array, "..."]]  # todo: just expand to attributes

      def __init__(self, model: M):
          """

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/normalizing_flow.py
@@ -9,7 +9,9 @@ from jaxtyping import Array, Key, Scalar

  from bayinx.core import Flow, Model, Variational

- M = TypeVar('M', bound=Model)
+ M = TypeVar("M", bound=Model)
+
+
  class NormalizingFlow(Variational, Generic[M]):
      """
      An ordered collection of diffeomorphisms that map a base distribution to a

bayinx-0.3.10/src/bayinx/core/__init__.py (removed file)
@@ -1,7 +0,0 @@
- from ._constraint import Constraint
- from ._flow import Flow
- from ._model import Model, constrain
- from ._parameter import Parameter
- from ._variational import Variational
-
- __all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]

{bayinx-0.3.10 → bayinx-0.3.12}/src/bayinx/mhx/vi/standard.py
@@ -1,4 +1,3 @@
-
  import equinox as eqx
  import jax.numpy as jnp
  import jax.random as jr
@@ -17,6 +16,7 @@ class Standard(Variational[M]):
      # Attributes
      - `dim`: Dimension of the parameter space.
      """
+
      dim: int

      def __init__(self, model: M):