bayinx 0.1.0.tar.gz → 0.2.3.tar.gz
- bayinx-0.2.3/.github/workflows/publish.yml +29 -0
- {bayinx-0.1.0 → bayinx-0.2.3}/.gitignore +5 -0
- bayinx-0.2.3/PKG-INFO +40 -0
- bayinx-0.2.3/README.md +27 -0
- bayinx-0.2.3/pyproject.toml +21 -0
- bayinx-0.2.3/src/bayinx/__init__.py +1 -0
- bayinx-0.2.3/src/bayinx/core/__init__.py +3 -0
- bayinx-0.2.3/src/bayinx/core/flow.py +68 -0
- bayinx-0.2.3/src/bayinx/core/model.py +55 -0
- bayinx-0.2.3/src/bayinx/core/utils.py +54 -0
- bayinx-0.2.3/src/bayinx/core/variational.py +159 -0
- bayinx-0.2.3/src/bayinx/dists/bernoulli.py +33 -0
- bayinx-0.2.3/src/bayinx/dists/normal.py +83 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/__init__.py +5 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/flows/__init__.py +3 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/flows/affine.py +68 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/flows/planar.py +76 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/flows/radial.py +95 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/flows/sylvester.py +76 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/meanfield.py +124 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/normalizing_flow.py +152 -0
- bayinx-0.2.3/src/bayinx/machinery/variational/standard.py +67 -0
- bayinx-0.2.3/tests/__init__.py +0 -0
- bayinx-0.2.3/tests/test_variational.py +130 -0
- {bayinx-0.1.0 → bayinx-0.2.3}/uv.lock +190 -12
- bayinx-0.1.0/.python-version +0 -1
- bayinx-0.1.0/.vscode/settings.json +0 -3
- bayinx-0.1.0/PKG-INFO +0 -8
- bayinx-0.1.0/packages/baycompx/pyproject.toml +0 -11
- bayinx-0.1.0/packages/baycompx/src/baycompx/__init__.py +0 -2
- bayinx-0.1.0/packages/baycompx/src/baycompx/dists/normal.py +0 -93
- bayinx-0.1.0/pyproject.toml +0 -21
- {bayinx-0.1.0/packages/baycompx/src/baycompx → bayinx-0.2.3/src/bayinx}/dists/__init__.py +0 -0
- /bayinx-0.1.0/README.md → /bayinx-0.2.3/src/bayinx/dists/binomial.py +0 -0
- /bayinx-0.1.0/packages/baycompx/README.md → /bayinx-0.2.3/src/bayinx/dists/gamma.py +0 -0
- /bayinx-0.1.0/packages/baycompx/src/baycompx/py.typed → /bayinx-0.2.3/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.1.0/src/bayinx → bayinx-0.2.3/src/bayinx/machinery}/__init__.py +0 -0
- {bayinx-0.1.0 → bayinx-0.2.3}/src/bayinx/py.typed +0 -0
bayinx-0.2.3/.github/workflows/publish.yml
ADDED
@@ -0,0 +1,29 @@
+name: "Publish"
+
+on:
+  release:
+    types: ["published"]
+
+jobs:
+  run:
+    name: "Build and publish release"
+    runs-on: ubuntu-latest
+    permissions:
+      id-token: write # Required for trusted publishing
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
+        with:
+          enable-cache: true
+
+      - name: Set up Python
+        run: uv python install 3.12
+
+      - name: Build
+        run: uv build
+
+      - name: Publish
+        run: uv publish
bayinx-0.2.3/PKG-INFO
ADDED
@@ -0,0 +1,40 @@
+Metadata-Version: 2.4
+Name: bayinx
+Version: 0.2.3
+Summary: Bayesian Inference with JAX
+Requires-Python: >=3.12
+Requires-Dist: equinox>=0.11.12
+Requires-Dist: jax>=0.4.38
+Requires-Dist: jaxtyping>=0.2.36
+Requires-Dist: optax>=0.2.4
+Requires-Dist: pytest-benchmark>=5.1.0
+Requires-Dist: pytest>=8.3.5
+Description-Content-Type: text/markdown
+
+# <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+
+The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of estimation methods (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
+
+In the short term, I'm going to focus on:
+1) Implementing as much machinery as feels sufficient.
+2) Figuring out how to design the `Model` superclass to have something like `Stan`'s `transformed parameters {}` block, but one that unifies transformations and constraints.
+3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
+
+In the long term, I'm going to focus on:
+1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (same as `Stan` does).
+2) How to make working with the posterior as easy as possible.
+   - That's a vague goal, but practically it means making it easy to evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, simulate the posterior predictive to generate plots, etc.
+
+Although this is somewhat separate from the goals of the project, if this pans out how I'm envisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
+
+Additionally, when I get around to it I'd like the package documentation to also include theoretical and implementation details for all machinery implemented (with overthinking boxes, because I do like that design from McElreath's book).
+
+
+# TODO
+- Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
+- Look into adaptively tuning Adam hyperparameters.
+- Control variates for mean-field VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
+- Low-rank affine flow?
+- Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
+- Learn how to generate documentation lol.
+- Figure out how to make `transform_pars` for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
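As a reference point for the declarative-syntax goal above, here is a minimal sketch of what defining a model currently looks like with the `Model` superclass from this release (see `src/bayinx/core/model.py` below); the `NormalMean` class and its data layout are hypothetical, not part of the package:

```python
import jax.numpy as jnp
from jaxtyping import Array, Scalar

from bayinx import Model
from bayinx.dists import normal


class NormalMean(Model):
    """Hypothetical model: y ~ Normal(mu, 1) with a Normal(0, 1) prior on mu."""

    def __init__(self):
        self.params = {"mu": jnp.array(0.0)}
        self.constraints = {}  # mu is unconstrained

    def eval(self, data: Array) -> Scalar:
        mu = self.transform_pars()["mu"]

        # target accumulates log prior + log likelihood, as in Stan
        target = normal.logprob(mu, jnp.array(0.0), jnp.array(1.0))
        target += normal.logprob(data, mu, jnp.array(1.0)).sum()
        return target
```

The long-term goal described above is essentially to compress `eval` bodies like this into `X ~ Normal(mu, 1)`-style statements.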
bayinx-0.2.3/README.md
ADDED
@@ -0,0 +1,27 @@
+# <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+
+The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of estimation methods (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
+
+In the short term, I'm going to focus on:
+1) Implementing as much machinery as feels sufficient.
+2) Figuring out how to design the `Model` superclass to have something like `Stan`'s `transformed parameters {}` block, but one that unifies transformations and constraints.
+3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
+
+In the long term, I'm going to focus on:
+1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (same as `Stan` does).
+2) How to make working with the posterior as easy as possible.
+   - That's a vague goal, but practically it means making it easy to evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, simulate the posterior predictive to generate plots, etc.
+
+Although this is somewhat separate from the goals of the project, if this pans out how I'm envisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
+
+Additionally, when I get around to it I'd like the package documentation to also include theoretical and implementation details for all machinery implemented (with overthinking boxes, because I do like that design from McElreath's book).
+
+
+# TODO
+- Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
+- Look into adaptively tuning Adam hyperparameters.
+- Control variates for mean-field VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
+- Low-rank affine flow?
+- Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
+- Learn how to generate documentation lol.
+- Figure out how to make `transform_pars` for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
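On the posterior-probability goal above: given draws from the posterior (however obtained), statements like $P(\mu_1 - \mu_2 > 0 \mid \mathcal{M})$ reduce to means of indicator functions over those draws. A generic sketch with stand-in draws, not bayinx API:

```python
import jax.numpy as jnp
import jax.random as jr

# Stand-in for posterior draws of shape (n_draws, 2); in practice these
# would come from a fitted variational distribution or an MCMC run.
draws = jr.normal(jr.PRNGKey(0), (4096, 2)) + jnp.array([0.5, 0.0])
mu1, mu2 = draws[:, 0], draws[:, 1]

# P(mu_1 - mu_2 > 0 | M): the mean of an indicator over draws
print(jnp.mean(mu1 - mu2 > 0.0))

# P(theta in [-1, 1] | M) works the same way for a single parameter
print(jnp.mean((mu1 >= -1.0) & (mu1 <= 1.0)))
```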
bayinx-0.2.3/pyproject.toml
ADDED
@@ -0,0 +1,21 @@
+[project]
+name = "bayinx"
+version = "0.2.3"
+description = "Bayesian Inference with JAX"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "equinox>=0.11.12",
+    "jax>=0.4.38",
+    "jaxtyping>=0.2.36",
+    "optax>=0.2.4",
+    "pytest>=8.3.5",
+    "pytest-benchmark>=5.1.0",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.pytest.ini_options]
+addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
bayinx-0.2.3/src/bayinx/__init__.py
ADDED
@@ -0,0 +1 @@
+from bayinx.core.model import Model as Model
bayinx-0.2.3/src/bayinx/core/flow.py
ADDED
@@ -0,0 +1,68 @@
+from abc import abstractmethod
+from typing import Callable, Dict, Self, Tuple
+
+import equinox as eqx
+from jaxtyping import Array, Float
+
+from bayinx.core.utils import __MyMeta
+
+
+class Flow(eqx.Module, metaclass=__MyMeta):
+    """
+    A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
+
+    # Attributes
+    - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
+    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
+
+    @abstractmethod
+    def forward(self, draws: Array) -> Array:
+        """
+        Computes the forward transformation of `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
+        """
+        Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
+
+        # Returns
+        A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
+        """
+        pass
+
+    def __init_subclass__(cls):
+        # Add constrain_pars method
+        def constrain_pars(self: Self) -> Dict[str, Array]:
+            """
+            Constrain `params` to the appropriate domain.
+
+            # Returns
+            A dictionary of transformed JAX Arrays representing the constrained parameters.
+            """
+            t_params = self.params
+
+            for par, map in self.constraints.items():
+                t_params[par] = map(t_params[par])
+
+            return t_params
+
+        cls.constrain_pars = eqx.filter_jit(constrain_pars)
+
+        # Add transform_pars method if not present
+        if not callable(getattr(cls, "transform_pars", None)):
+            def transform_pars(self: Self) -> Dict[str, Array]:
+                """
+                Apply a custom transformation to `params` if needed.
+
+                # Returns
+                A dictionary of transformed JAX Arrays representing the transformed parameters.
+                """
+                return self.constrain_pars()
+
+            cls.transform_pars = eqx.filter_jit(transform_pars)
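To make the `__init_subclass__` hook above concrete: a subclass only supplies `forward` and `adjust_density`, and inherits jitted `constrain_pars`/`transform_pars` for free. A hypothetical shift-only flow, declaring its fields the same way `Affine` does later in this diff:

```python
from typing import Callable, Dict, Tuple

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float

from bayinx.core import Flow


class Shift(Flow):
    """Hypothetical flow that translates draws by a learned offset."""

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
        eqx.field(static=True)
    )

    def __init__(self, dim: int):
        self.params = {"shift": jnp.zeros(dim)}
        self.constraints = {}

    def forward(self, draws: Array) -> Array:
        return draws + self.params["shift"]

    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
        # A pure translation has a unit Jacobian, so the ladj is zero.
        return jnp.zeros(draws.shape[0]), self.forward(draws)
```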
bayinx-0.2.3/src/bayinx/core/model.py
ADDED
@@ -0,0 +1,55 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict
+
+import equinox as eqx
+from jaxtyping import Array, Scalar
+
+from bayinx.core.utils import __MyMeta
+
+
+class Model(eqx.Module, metaclass=__MyMeta):
+    """
+    A superclass used to define probabilistic models.
+
+    # Attributes
+    - `params`: A dictionary of JAX Arrays representing parameters of the model.
+    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+    """
+
+    params: Dict[str, Array]
+    constraints: Dict[str, Callable[[Array], Array]]
+
+    @abstractmethod
+    def eval(self, data: Any) -> Scalar:
+        pass
+
+    def __init_subclass__(cls):
+        # Add constrain_pars method
+        def constrain_pars(self: Model) -> Dict[str, Array]:
+            """
+            Constrain `params` to the appropriate domain.
+
+            # Returns
+            A dictionary of transformed JAX Arrays representing the constrained parameters.
+            """
+            t_params = self.params
+
+            for par, map in self.constraints.items():
+                t_params[par] = map(t_params[par])
+
+            return t_params
+
+        cls.constrain_pars = eqx.filter_jit(constrain_pars)
+
+        # Add transform_pars method if not present
+        if not callable(getattr(cls, "transform_pars", None)):
+            def transform_pars(self: Model) -> Dict[str, Array]:
+                """
+                Apply a custom transformation to `params` if needed.
+
+                # Returns
+                A dictionary of transformed JAX Arrays representing the transformed parameters.
+                """
+                return self.constrain_pars()
+
+            cls.transform_pars = eqx.filter_jit(transform_pars)
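One note on the constraints mechanism: parameters are stored unconstrained in `params`, and the generated `constrain_pars` pipes each one through its entry in `constraints`. A hypothetical model with a positivity-constrained scale:

```python
import jax.numpy as jnp
from jaxtyping import Array, Scalar

from bayinx import Model
from bayinx.dists import normal


class NormalScale(Model):
    """Hypothetical model with sigma stored on the log scale."""

    def __init__(self):
        self.params = {"mu": jnp.array(0.0), "sigma": jnp.array(0.0)}
        # `sigma` is optimized unconstrained; exp maps it to (0, inf)
        self.constraints = {"sigma": jnp.exp}

    def eval(self, data: Array) -> Scalar:
        pars = self.transform_pars()
        return normal.logprob(data, pars["mu"], pars["sigma"]).sum()
```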
bayinx-0.2.3/src/bayinx/core/utils.py
ADDED
@@ -0,0 +1,54 @@
+from typing import Callable, Dict
+
+import equinox as eqx
+from jaxtyping import Array
+
+
+class __MyMeta(type(eqx.Module)):
+    """
+    Metaclass to ensure attribute types are respected.
+    """
+
+    def __call__(cls, *args, **kwargs):
+        obj = super().__call__(*args, **kwargs)
+
+        # Check parameters are a Dict of JAX Arrays
+        if not isinstance(obj.params, Dict):
+            raise ValueError(
+                f"Model {cls.__name__} must initialize 'params' as a dictionary."
+            )
+
+        for key, value in obj.params.items():
+            if not isinstance(value, Array):
+                raise TypeError(f"Parameter '{key}' must be a JAX Array.")
+
+        # Check constraints are a Dict of functions
+        if not isinstance(obj.constraints, Dict):
+            raise ValueError(
+                f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
+            )
+
+        for key, value in obj.constraints.items():
+            if not isinstance(value, Callable):
+                raise TypeError(f"Constraint for parameter '{key}' must be a function.")
+
+        # Check that the constrain_pars method returns a dict equivalent to `params`
+        t_params: Dict[str, Array] = obj.constrain_pars()
+
+        if not isinstance(t_params, Dict):
+            raise ValueError(
+                f"The 'constrain_pars' method of {cls.__name__} must return a Dict of JAX Arrays."
+            )
+
+        for key, value in t_params.items():
+            if not isinstance(value, Array):
+                raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
+
+            if not value.shape == obj.params[key].shape:
+                raise ValueError(
+                    f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
+                )
+
+        # Check transform_pars
+
+        return obj
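The practical effect of this metaclass is that malformed subclasses fail at construction rather than mid-optimization; a hypothetical example of the `params` check firing:

```python
import jax.numpy as jnp

from bayinx import Model


class Broken(Model):
    def __init__(self):
        self.params = {"mu": 0.0}  # a Python float, not a JAX Array
        self.constraints = {}

    def eval(self, data=None):
        return jnp.array(0.0)


Broken()  # raises TypeError: Parameter 'mu' must be a JAX Array.
```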
bayinx-0.2.3/src/bayinx/core/variational.py
ADDED
@@ -0,0 +1,159 @@
+from abc import abstractmethod
+from typing import Any, Callable, Self, Tuple
+
+import equinox as eqx
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.random as jr
+import optax as opx
+from jaxtyping import Array, Float, Key, PyTree, Scalar
+from optax import GradientTransformation, OptState, Schedule
+
+from bayinx.core import Model
+
+
+class Variational(eqx.Module):
+    """
+    A superclass used to define variational methods.
+
+    # Attributes
+    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
+    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
+    """
+
+    _unflatten: Callable[[Float[Array, "..."]], Model]
+    _constraints: Model
+
+    @abstractmethod
+    def sample(self, n: int, key: Key) -> Array:
+        """
+        Sample from the variational distribution.
+        """
+        pass
+
+    @abstractmethod
+    def eval(self, draws: Array) -> Array:
+        """
+        Evaluate the variational distribution at `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
+        """
+        Evaluate the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
+        """
+        Evaluate the gradient of the ELBO.
+        """
+        pass
+
+    @abstractmethod
+    def filter_spec(self):
+        """
+        Filter specification for dynamic and static components of the `Variational`.
+        """
+        pass
+
+    def __init_subclass__(cls):
+        """
+        Construct methods that are shared across all VI methods.
+        """
+
+        def eval_model(self, draws: Array, data: Any = None) -> Array:
+            """
+            Reconstruct models from variational draws and evaluate their posterior density.
+
+            # Parameters
+            - `draws`: A set of variational draws.
+            - `data`: Data used to evaluate the posterior (if needed).
+            """
+            # Unflatten variational draw
+            model: Model = self._unflatten(draws)
+
+            # Combine with constraints
+            model: Model = eqx.combine(model, self._constraints)
+
+            # Evaluate posterior density
+            return model.eval(data)
+
+        cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))
+
+        def fit(
+            self,
+            max_iters: int,
+            data: Any = None,
+            learning_rate: float = 1,
+            tolerance: float = 1e-4,
+            var_draws: int = 1,
+            key: Key = jr.PRNGKey(0),
+        ) -> Self:
+            """
+            Optimize the variational distribution.
+
+            # Parameters
+            - `max_iters`: Maximum number of iterations for the optimization loop.
+            - `data`: Data to evaluate the posterior density with (if available).
+            - `learning_rate`: Initial learning rate for optimization.
+            - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
+            - `var_draws`: Number of variational draws per iteration.
+            - `key`: A PRNG key.
+            """
+            # Construct scheduler
+            schedule: Schedule = opx.cosine_decay_schedule(
+                init_value=learning_rate, decay_steps=max_iters
+            )
+
+            # Initialize optimizer
+            optim: GradientTransformation = opx.chain(
+                opx.scale(-1.0), opx.nadam(schedule)
+            )
+            opt_state: OptState = optim.init(eqx.filter(self, self.filter_spec()))
+
+            # Optimization loop helper functions
+            @eqx.filter_jit
+            def condition(state: Tuple[Self, OptState, Scalar, Key]):
+                # Unpack iteration state
+                self, opt_state, i, key = state
+
+                return i < max_iters
+
+            @eqx.filter_jit
+            def body(state: Tuple[Self, OptState, Scalar, Key]):
+                # Unpack iteration state
+                self, opt_state, i, key = state
+
+                # Update iteration
+                i = i + 1
+
+                # Update PRNG key
+                key, _ = jr.split(key)
+
+                # Compute gradient of the ELBO
+                updates: PyTree = self.elbo_grad(var_draws, key, data)
+
+                # Compute updates
+                updates, opt_state = optim.update(
+                    updates, opt_state, eqx.filter(self, self.filter_spec())
+                )
+
+                # Update variational distribution
+                self: Self = eqx.apply_updates(self, updates)
+
+                return self, opt_state, i, key
+
+            # Run optimization loop
+            self = lax.while_loop(
+                cond_fun=condition,
+                body_fun=body,
+                init_val=(self, opt_state, jnp.array(0, jnp.uint32), key),
+            )[0]
+
+            return self
+
+        cls.fit = eqx.filter_jit(fit)
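For reference, the estimator each concrete subclass's `elbo` implements is the standard reparameterized Monte Carlo ELBO, $\mathbb{E}_q[\log p(\theta, y) - \log q(\theta)]$. A self-contained sketch for a mean-field normal $q$ against a stand-in log posterior (illustrating the math, not the bayinx API):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

from bayinx.dists import normal


def logpost(theta):
    # Stand-in unnormalized log posterior density
    return normal.ulogprob(theta, jnp.array(2.0), jnp.array(0.5)).sum()


def elbo_estimate(mu, log_sigma, n, key):
    sigma = jnp.exp(log_sigma)

    # Reparameterized draws from q = Normal(mu, sigma)
    eps = jr.normal(key, (n,) + mu.shape)
    draws = mu + sigma * eps

    # ELBO ~= mean over draws of log p(theta, y) - log q(theta)
    logq = normal.logprob(draws, mu, sigma).sum(axis=-1)
    logp = jax.vmap(logpost)(draws)
    return jnp.mean(logp - logq)


print(elbo_estimate(jnp.zeros(1), jnp.zeros(1), n=64, key=jr.PRNGKey(0)))
```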
bayinx-0.2.3/src/bayinx/dists/bernoulli.py
ADDED
@@ -0,0 +1,33 @@
+import jax.lax as lax
+from jaxtyping import Array, ArrayLike, Real, UInt
+
+
+# MARK: Functions ----
+def prob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+    """
+    The probability mass function (PMF) for a Bernoulli distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the PMF.
+    - `p`: The probability parameter(s).
+
+    # Returns
+    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+    """
+
+    return lax.pow(p, x) * lax.pow(1 - p, 1 - x)
+
+
+def logprob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+    """
+    The log probability mass function (log PMF) for a Bernoulli distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the log PMF.
+    - `p`: The probability parameter(s).
+
+    # Returns
+    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+    """
+
+    return lax.log(p) * x + (1 - x) * lax.log(1 - p)
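Both functions broadcast elementwise, so `x` and `p` may be any mutually broadcastable shapes; a quick hypothetical check:

```python
import jax.numpy as jnp

from bayinx.dists import bernoulli

x = jnp.array([0, 1, 1])
p = jnp.array(0.7)

print(bernoulli.prob(x, p))     # ~[0.3 0.7 0.7]
print(bernoulli.logprob(x, p))  # elementwise log of the above
```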
bayinx-0.2.3/src/bayinx/dists/normal.py
ADDED
@@ -0,0 +1,83 @@
+# MARK: Imports ----
+import jax.lax as _lax
+
+## Typing
+from jaxtyping import Array, Real
+
+# MARK: Constants
+_PI = 3.141592653589793
+
+
+# MARK: Functions ----
+def prob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The probability density function (PDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the PDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / (
+        sigma * _lax.sqrt(2.0 * _PI)
+    )
+
+
+def logprob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The log of the probability density function (log PDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the log PDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return -_lax.log(sigma * _lax.sqrt(2.0 * _PI)) - 0.5 * _lax.square((x - mu) / sigma)
+
+
+def uprob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The unnormalized probability density function (uPDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the uPDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma
+
+
+def ulogprob(
+    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+) -> Real[Array, "..."]:
+    """
+    The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
+
+    # Parameters
+    - `x`: Value(s) at which to evaluate the log uPDF.
+    - `mu`: The mean/location parameter(s).
+    - `sigma`: The non-negative standard deviation parameter(s).
+
+    # Returns
+    The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+    """
+
+    return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma)
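The unnormalized variants drop only the constant $\sqrt{2\pi}$ factor, which is the useful property for MCMC/VI where constants cancel; a quick hypothetical sanity check:

```python
import jax.numpy as jnp

from bayinx.dists import normal

x, mu, sigma = jnp.array(1.5), jnp.array(0.0), jnp.array(2.0)

# uprob differs from prob by exactly sqrt(2 * pi)
print(jnp.allclose(
    normal.uprob(x, mu, sigma),
    normal.prob(x, mu, sigma) * jnp.sqrt(2.0 * jnp.pi),
))  # True

# and ulogprob differs from logprob by its log
print(jnp.allclose(
    normal.ulogprob(x, mu, sigma),
    normal.logprob(x, mu, sigma) + 0.5 * jnp.log(2.0 * jnp.pi),
))  # True
```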
bayinx-0.2.3/src/bayinx/machinery/variational/flows/affine.py
ADDED
@@ -0,0 +1,68 @@
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, Float, Scalar
+
+from bayinx.core import Flow
+
+
+class Affine(Flow):
+    """
+    An affine flow.
+
+    # Attributes
+    - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
+    - `constraints`: A dictionary of constraining transformations.
+    """
+
+    params: Dict[str, Float[Array, "..."]]
+    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+        eqx.field(static=True)
+    )
+
+    def __init__(self, dim: int):
+        """
+        Initializes an affine flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.params = {
+            "shift": jnp.zeros(dim),
+            "scale": jnp.zeros((dim, dim)),
+        }
+
+        self.constraints = {"scale": lambda m: jnp.tril(jnp.exp(m))}
+
+    @eqx.filter_jit
+    def forward(self, draws: Array) -> Array:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        shift: Array = params["shift"]
+        scale: Array = params["scale"]
+
+        # Compute forward transformation
+        draws = draws @ scale + shift
+
+        return draws
+
+    @eqx.filter_jit
+    @partial(jax.vmap, in_axes=(None, 0))
+    def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+        params = self.constrain_pars()
+
+        # Extract parameters
+        shift: Array = params["shift"]
+        scale: Array = params["scale"]
+
+        # Compute forward transformation
+        draws = draws @ scale + shift
+
+        # Compute the log-abs-det of the Jacobian
+        ladj: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+        return ladj, draws
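A brief usage sketch of `Affine` (hypothetical values): at initialization the constraint maps `scale` to `tril(exp(0)) = tril(ones)`, whose diagonal is all ones, so the log-abs-det Jacobian starts at zero.

```python
import jax.random as jr

from bayinx.machinery.variational.flows.affine import Affine

flow = Affine(dim=2)
draws = jr.normal(jr.PRNGKey(0), (5, 2))

# Forward transformation of a batch of draws
transformed = flow.forward(draws)

# Per-draw log-abs-det Jacobian plus transformed draws;
# all zeros at initialization since diag(scale) starts at ones.
ladj, transformed = flow.adjust_density(draws)
print(ladj.shape, transformed.shape)  # (5,) (5, 2)
```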