bayinx 0.1.0__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. bayinx-0.2.3/.github/workflows/publish.yml +29 -0
  2. {bayinx-0.1.0 → bayinx-0.2.3}/.gitignore +5 -0
  3. bayinx-0.2.3/PKG-INFO +40 -0
  4. bayinx-0.2.3/README.md +27 -0
  5. bayinx-0.2.3/pyproject.toml +21 -0
  6. bayinx-0.2.3/src/bayinx/__init__.py +1 -0
  7. bayinx-0.2.3/src/bayinx/core/__init__.py +3 -0
  8. bayinx-0.2.3/src/bayinx/core/flow.py +68 -0
  9. bayinx-0.2.3/src/bayinx/core/model.py +55 -0
  10. bayinx-0.2.3/src/bayinx/core/utils.py +54 -0
  11. bayinx-0.2.3/src/bayinx/core/variational.py +159 -0
  12. bayinx-0.2.3/src/bayinx/dists/bernoulli.py +33 -0
  13. bayinx-0.2.3/src/bayinx/dists/normal.py +83 -0
  14. bayinx-0.2.3/src/bayinx/machinery/variational/__init__.py +5 -0
  15. bayinx-0.2.3/src/bayinx/machinery/variational/flows/__init__.py +3 -0
  16. bayinx-0.2.3/src/bayinx/machinery/variational/flows/affine.py +68 -0
  17. bayinx-0.2.3/src/bayinx/machinery/variational/flows/planar.py +76 -0
  18. bayinx-0.2.3/src/bayinx/machinery/variational/flows/radial.py +95 -0
  19. bayinx-0.2.3/src/bayinx/machinery/variational/flows/sylvester.py +76 -0
  20. bayinx-0.2.3/src/bayinx/machinery/variational/meanfield.py +124 -0
  21. bayinx-0.2.3/src/bayinx/machinery/variational/normalizing_flow.py +152 -0
  22. bayinx-0.2.3/src/bayinx/machinery/variational/standard.py +67 -0
  23. bayinx-0.2.3/tests/__init__.py +0 -0
  24. bayinx-0.2.3/tests/test_variational.py +130 -0
  25. {bayinx-0.1.0 → bayinx-0.2.3}/uv.lock +190 -12
  26. bayinx-0.1.0/.python-version +0 -1
  27. bayinx-0.1.0/.vscode/settings.json +0 -3
  28. bayinx-0.1.0/PKG-INFO +0 -8
  29. bayinx-0.1.0/packages/baycompx/pyproject.toml +0 -11
  30. bayinx-0.1.0/packages/baycompx/src/baycompx/__init__.py +0 -2
  31. bayinx-0.1.0/packages/baycompx/src/baycompx/dists/normal.py +0 -93
  32. bayinx-0.1.0/pyproject.toml +0 -21
  33. {bayinx-0.1.0/packages/baycompx/src/baycompx → bayinx-0.2.3/src/bayinx}/dists/__init__.py +0 -0
  34. /bayinx-0.1.0/README.md → /bayinx-0.2.3/src/bayinx/dists/binomial.py +0 -0
  35. /bayinx-0.1.0/packages/baycompx/README.md → /bayinx-0.2.3/src/bayinx/dists/gamma.py +0 -0
  36. /bayinx-0.1.0/packages/baycompx/src/baycompx/py.typed → /bayinx-0.2.3/src/bayinx/dists/gamma2.py +0 -0
  37. {bayinx-0.1.0/src/bayinx → bayinx-0.2.3/src/bayinx/machinery}/__init__.py +0 -0
  38. {bayinx-0.1.0 → bayinx-0.2.3}/src/bayinx/py.typed +0 -0
@@ -0,0 +1,29 @@
1
+ name: "Publish"
2
+
3
+ on:
4
+ release:
5
+ types: ["published"]
6
+
7
+ jobs:
8
+ run:
9
+ name: "Build and publish release"
10
+ runs-on: ubuntu-latest
11
+ permissions:
12
+ id-token: write # Required for trusted publishing
13
+
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+
17
+ - name: Install uv
18
+ uses: astral-sh/setup-uv@v3
19
+ with:
20
+ enable-cache: true
21
+
22
+ - name: Set up Python
23
+ run: uv python install 3.12
24
+
25
+ - name: Build
26
+ run: uv build
27
+
28
+ - name: Publish
29
+ run: uv publish
@@ -5,6 +5,11 @@ build/
5
5
  dist/
6
6
  wheels/
7
7
  *.egg-info
8
+ .pytest_cache
9
+ .benchmarks
8
10
 
9
11
  # Virtual environments
10
12
  .venv
13
+
14
+ # Other
15
+ .ruff_cache
bayinx-0.2.3/PKG-INFO ADDED
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.4
2
+ Name: bayinx
3
+ Version: 0.2.3
4
+ Summary: Bayesian Inference with JAX
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: equinox>=0.11.12
7
+ Requires-Dist: jax>=0.4.38
8
+ Requires-Dist: jaxtyping>=0.2.36
9
+ Requires-Dist: optax>=0.2.4
10
+ Requires-Dist: pytest-benchmark>=5.1.0
11
+ Requires-Dist: pytest>=8.3.5
12
+ Description-Content-Type: text/markdown
13
+
14
+ # <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
15
+
16
+ The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of estimation methods (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
17
+
18
+ In the short-term, I'm going to focus on:
19
+ 1) Implementing as much machinery as I feel is enough.
20
+ 2) Figuring out how to design the `Model` superclass to have something like the `transformed pars {}` block, but one that unifies transformations and constraints.
21
+ 3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
22
+
23
+ In the long-term, I'm going to focus on:
24
+ 1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (as `Stan` does).
25
+ 2) How to make working with the posterior as easy as possible.
26
+ - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
27
+
28
+ Although this is somewhat separate from the goals of the project, if this pans out the way I'm envisioning it, I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
29
+
30
+ Additionally, when I get around to it, I'd like the package documentation to also include theoretical and implementation details for all implemented machinery (with overthinking boxes, because I do like that design from McElreath's book).
31
+
32
+
33
+ # TODO
34
+ - Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
35
+ - Look into adaptively tuning ADAM hyperparameters.
36
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
37
+ - Low-rank affine flow?
38
+ - Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
39
+ - Learn how to generate documentation lol.
40
+ - Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
bayinx-0.2.3/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
2
+
3
+ The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of estimation methods (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
4
+
5
+ In the short-term, I'm going to focus on:
6
+ 1) Implementing as much machinery as I feel is enough.
7
+ 2) Figuring out how to design the `Model` superclass to have something like the `transformed pars {}` block, but one that unifies transformations and constraints.
8
+ 3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
9
+
10
+ In the long-term, I'm going to focus on:
11
+ 1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (as `Stan` does).
12
+ 2) How to make working with the posterior as easy as possible.
13
+ - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
14
+
15
+ Although this is somewhat separate from the goals of the project, if this pans out the way I'm envisioning it, I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
16
+
17
+ Additionally, when I get around to it, I'd like the package documentation to also include theoretical and implementation details for all implemented machinery (with overthinking boxes, because I do like that design from McElreath's book).
18
+
19
+
20
+ # TODO
21
+ - Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
22
+ - Look into adaptively tuning ADAM hyperparameters.
23
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
24
+ - Low-rank affine flow?
25
+ - Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
26
+ - Learn how to generate documentation lol.
27
+ - Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
@@ -0,0 +1,21 @@
1
+ [project]
2
+ name = "bayinx"
3
+ version = "0.2.3"
4
+ description = "Bayesian Inference with JAX"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "equinox>=0.11.12",
9
+ "jax>=0.4.38",
10
+ "jaxtyping>=0.2.36",
11
+ "optax>=0.2.4",
12
+ "pytest>=8.3.5",
13
+ "pytest-benchmark>=5.1.0",
14
+ ]
15
+
16
+ [build-system]
17
+ requires = ["hatchling"]
18
+ build-backend = "hatchling.build"
19
+
20
+ [tool.pytest.ini_options]
21
+ addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
@@ -0,0 +1 @@
1
+ from bayinx.core.model import Model as Model
@@ -0,0 +1,3 @@
1
+ from bayinx.core.flow import Flow as Flow
2
+ from bayinx.core.model import Model as Model
3
+ from bayinx.core.variational import Variational as Variational
@@ -0,0 +1,68 @@
1
+ from abc import abstractmethod
2
+ from typing import Callable, Dict, Self, Tuple
3
+
4
+ import equinox as eqx
5
+ from jaxtyping import Array, Float
6
+
7
+ from bayinx.core.utils import __MyMeta
8
+
9
+
10
class Flow(eqx.Module, metaclass=__MyMeta):
    """
    A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.

    Subclasses implement `forward` and `adjust_density`. When a subclass is created,
    `constrain_pars` is attached to it, and `transform_pars` is attached too unless the
    subclass (or an ancestor) already provides a callable `transform_pars`.

    # Attributes
    - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
    """

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]

    @abstractmethod
    def forward(self, draws: Array) -> Array:
        """
        Computes the forward transformation of `draws`.
        """
        pass

    @abstractmethod
    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
        """
        Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.

        # Returns
        A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
        """
        pass

    def __init_subclass__(cls):
        # Add constrain_pars method (always installed on every subclass)
        def constrain_pars(self: Self) -> Dict[str, Array]:
            """
            Constrain `params` to the appropriate domain.

            # Returns
            A new dictionary of JAX Arrays representing the constrained parameters.
            Parameters without an entry in `constraints` are passed through unchanged;
            `self.params` itself is not modified.
            """
            # Copy first so the constrained values are never written back into the
            # instance's own `params` dict.
            t_params = dict(self.params)

            for par, constrain in self.constraints.items():
                t_params[par] = constrain(t_params[par])

            return t_params

        cls.constrain_pars = eqx.filter_jit(constrain_pars)

        # Add transform_pars method if not present
        if not callable(getattr(cls, "transform_pars", None)):

            def transform_pars(self: Self) -> Dict[str, Array]:
                """
                Apply a custom transformation to `params` if needed.

                The default implementation only applies the constraints.

                # Returns
                A dictionary of transformed JAX Arrays representing the transformed parameters.
                """
                return self.constrain_pars()

            cls.transform_pars = eqx.filter_jit(transform_pars)
@@ -0,0 +1,55 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Callable, Dict
3
+
4
+ import equinox as eqx
5
+ from jaxtyping import Array, Scalar
6
+
7
+ from bayinx.core.utils import __MyMeta
8
+
9
+
10
class Model(eqx.Module, metaclass=__MyMeta):
    """
    A superclass used to define probabilistic models.

    Subclasses implement `eval`. When a subclass is created, `constrain_pars` is attached
    to it, and `transform_pars` is attached too unless the subclass (or an ancestor)
    already provides a callable `transform_pars`.

    # Attributes
    - `params`: A dictionary of JAX Arrays representing parameters of the model.
    - `constraints`: A dictionary of functions that constrain their corresponding parameter.
    """

    params: Dict[str, Array]
    constraints: Dict[str, Callable[[Array], Array]]

    @abstractmethod
    def eval(self, data: Any) -> Scalar:
        """
        Evaluate the model at its current parameters given `data`, returning a scalar.

        The variational machinery treats this value as the model's posterior density.
        """
        pass

    def __init_subclass__(cls):
        # Add constrain_pars method (always installed on every subclass)
        def constrain_pars(self: Model) -> Dict[str, Array]:
            """
            Constrain `params` to the appropriate domain.

            # Returns
            A new dictionary of JAX Arrays representing the constrained parameters.
            Parameters without an entry in `constraints` are passed through unchanged;
            `self.params` itself is not modified.
            """
            # Copy first so the constrained values are never written back into the
            # instance's own `params` dict.
            t_params = dict(self.params)

            for par, constrain in self.constraints.items():
                t_params[par] = constrain(t_params[par])

            return t_params

        cls.constrain_pars = eqx.filter_jit(constrain_pars)

        # Add transform_pars method if not present
        if not callable(getattr(cls, "transform_pars", None)):

            def transform_pars(self: Model) -> Dict[str, Array]:
                """
                Apply a custom transformation to `params` if needed.

                The default implementation only applies the constraints.

                # Returns
                A dictionary of transformed JAX Arrays representing the transformed parameters.
                """
                return self.constrain_pars()

            cls.transform_pars = eqx.filter_jit(transform_pars)
@@ -0,0 +1,54 @@
1
+ from typing import Callable, Dict
2
+
3
+ import equinox as eqx
4
+ from jaxtyping import Array
5
+
6
+
7
class __MyMeta(type(eqx.Module)):
    """
    Metaclass that validates the `params`/`constraints` attributes of each new instance.

    After the instance has been constructed normally, this checks that:
    - `params` is a dict whose values are all JAX Arrays;
    - `constraints` is a dict whose values are all callable;
    - `constrain_pars()` returns a dict of JAX Arrays, each with the same shape as the
      corresponding entry of `params`.

    It is shared by `Model` and `Flow`, so messages name the concrete class rather than
    assuming it is a model.
    """

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)

        # Check parameters are a dict of JAX Arrays
        if not isinstance(obj.params, dict):
            raise ValueError(
                f"{cls.__name__} must initialize 'params' as a dictionary."
            )

        for key, value in obj.params.items():
            if not isinstance(value, Array):
                raise TypeError(f"Parameter '{key}' must be a JAX Array.")

        # Check constraints are a dict of functions
        if not isinstance(obj.constraints, dict):
            raise ValueError(
                f"{cls.__name__} must initialize 'constraints' as a dictionary."
            )

        for key, value in obj.constraints.items():
            if not callable(value):
                raise TypeError(f"Constraint for parameter '{key}' must be a function.")

        # Check that `constrain_pars` returns a dict equivalent to `params`. This also
        # runs every constraint once, so a broken constraint fails at construction time.
        t_params: Dict[str, Array] = obj.constrain_pars()

        if not isinstance(t_params, dict):
            raise ValueError(
                f"The 'constrain_pars' method of {cls.__name__} must return a Dict of JAX Arrays."
            )

        for key, value in t_params.items():
            if not isinstance(value, Array):
                raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")

            # Same exception type as the implicit lookup failure, but with a clear message.
            if key not in obj.params:
                raise KeyError(
                    f"Constrained parameter '{key}' has no unconstrained counterpart in 'params'."
                )

            if value.shape != obj.params[key].shape:
                raise ValueError(
                    f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
                )

        # NOTE: `transform_pars` is not validated here.
        return obj
@@ -0,0 +1,159 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Callable, Self, Tuple
3
+
4
+ import equinox as eqx
5
+ import jax
6
+ import jax.lax as lax
7
+ import jax.numpy as jnp
8
+ import jax.random as jr
9
+ import optax as opx
10
+ from jaxtyping import Array, Float, Key, PyTree, Scalar
11
+ from optax import GradientTransformation, OptState, Schedule
12
+
13
+ from bayinx.core import Model
14
+
15
+
16
class Variational(eqx.Module):
    """
    A superclass used to define variational methods.

    Subclasses implement the abstract methods below. When a subclass is created, the
    shared `eval_model` and `fit` methods are attached to it.

    # Attributes
    - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
    - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
    """

    _unflatten: Callable[[Float[Array, "..."]], Model]
    _constraints: Model

    @abstractmethod
    def sample(self, n: int, key: Key) -> Array:
        """
        Sample from the variational distribution.
        """
        pass

    @abstractmethod
    def eval(self, draws: Array) -> Array:
        """
        Evaluate the variational distribution at `draws`.
        """
        pass

    @abstractmethod
    def elbo(self, n: int, key: Key, data: Any = None) -> Array:
        """
        Evaluate the ELBO.
        """
        pass

    @abstractmethod
    def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
        """
        Evaluate the gradient of the ELBO.
        """
        pass

    @abstractmethod
    def filter_spec(self):
        """
        Filter specification for dynamic and static components of the `Variational`.
        """
        pass

    def __init_subclass__(cls):
        """
        Construct methods that are shared across all VI methods.
        """

        def eval_model(self, draws: Array, data: Any = None) -> Array:
            """
            Reconstruct models from variational draws and evaluate their posterior density.

            # Parameters
            - `draws`: A set of variational draws.
            - `data`: Data used to evaluate the posterior(if needed).
            """
            # Unflatten variational draw
            model: Model = self._unflatten(draws)

            # Combine with constraints
            model = eqx.combine(model, self._constraints)

            # Evaluate posterior density
            return model.eval(data)

        # Map over the leading axis of `draws` only; `self` and `data` are shared.
        cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))

        def fit(
            self,
            max_iters: int,
            data: Any = None,
            learning_rate: float = 1,
            tolerance: float = 1e-4,
            var_draws: int = 1,
            key: Key = jr.PRNGKey(0),
        ) -> Self:
            """
            Optimize the variational distribution.

            # Parameters
            - `max_iters`: Maximum number of iterations for the optimization loop.
            - `data`: Data to evaluate the posterior density with(if available).
            - `learning_rate`: Initial learning rate for optimization.
            - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
              NOTE: currently unused; the loop always runs exactly `max_iters` iterations.
            - `var_draws`: Number of variational draws to draw each iteration.
            - `key`: A PRNG key.

            # Returns
            The optimized variational distribution.
            """
            # Construct scheduler
            schedule: Schedule = opx.cosine_decay_schedule(
                init_value=learning_rate, decay_steps=max_iters
            )

            # Initialize optimizer; `scale(-1.0)` negates the ELBO gradient so the
            # optimizer's descent step becomes an ascent step on the ELBO.
            optim: GradientTransformation = opx.chain(
                opx.scale(-1.0), opx.nadam(schedule)
            )
            opt_state: OptState = optim.init(eqx.filter(self, self.filter_spec()))

            # Optimization loop helper functions
            @eqx.filter_jit
            def condition(state: Tuple[Self, OptState, Scalar, Key]):
                # Only the iteration counter matters for termination.
                _, _, i, _ = state

                return i < max_iters

            @eqx.filter_jit
            def body(state: Tuple[Self, OptState, Scalar, Key]):
                # Unpack iteration state
                vari, opt_state, i, key = state

                # Update iteration
                i = i + 1

                # Split off a fresh subkey for this iteration's draws and carry the
                # other key forward; a key that has been used for sampling must not be
                # split again, or successive iterations draw correlated randomness.
                key, subkey = jr.split(key)

                # Compute gradient of the ELBO
                updates: PyTree = vari.elbo_grad(var_draws, subkey, data)

                # Compute updates
                updates, opt_state = optim.update(
                    updates, opt_state, eqx.filter(vari, vari.filter_spec())
                )

                # Update variational distribution
                vari = eqx.apply_updates(vari, updates)

                return vari, opt_state, i, key

            # Run optimization loop; keep only the optimized distribution.
            return lax.while_loop(
                cond_fun=condition,
                body_fun=body,
                init_val=(self, opt_state, jnp.array(0, jnp.uint32), key),
            )[0]

        cls.fit = eqx.filter_jit(fit)
@@ -0,0 +1,33 @@
1
+ import jax.lax as lax
2
+ from jaxtyping import Array, ArrayLike, Real, UInt
3
+
4
+
5
+ # MARK: Functions ----
6
def prob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
    """
    The probability mass function (PMF) for a Bernoulli distribution.

    Computed as `p**x * (1 - p)**(1 - x)`.

    # Parameters
    - `x`: Value(s) at which to evaluate the PMF (expected to be 0 or 1).
    - `p`: The probability parameter(s), i.e. the probability that `x == 1`.

    # Returns
    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
    """
    # `jnp.power` applies NumPy broadcasting and dtype promotion. `lax.pow` does not
    # broadcast arguments of different rank, which broke the documented contract.
    import jax.numpy as jnp

    return jnp.power(p, x) * jnp.power(1 - p, 1 - x)
19
+
20
+
21
def logprob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
    """
    The log probability mass function (log PMF) for a Bernoulli distribution.

    Computed as `x * log(p) + (1 - x) * log(1 - p)`, with the convention
    `0 * log(0) = 0` so that the boundary cases `p = 0` / `p = 1` are finite.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PMF (expected to be 0 or 1).
    - `p`: The probability parameter(s), i.e. the probability that `x == 1`.

    # Returns
    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
    """
    # `xlogy(a, b)` is 0 when a == 0, so the term for the outcome that did not occur
    # never evaluates to 0 * log(0) = nan. `xlog1py(a, -p)` computes a * log(1 - p)
    # accurately even for small p.
    from jax.scipy.special import xlog1py, xlogy

    return xlogy(x, p) + xlog1py(1 - x, -p)
@@ -0,0 +1,83 @@
1
+ # MARK: Imports ----
2
+ import jax.lax as _lax
3
+
4
+ ## Typing
5
+ from jaxtyping import Array, Real
6
+
7
# MARK: Constants
# Pi to double precision; `prob` and `logprob` use it in the normalizing term sqrt(2 * pi).
_PI = 3.141592653589793
9
+
10
+
11
+ # MARK: Functions ----
12
def prob(
    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
) -> Real[Array, "..."]:
    """
    Evaluate the Normal probability density function (PDF).

    # Parameters
    - `x`: Point(s) at which the density is evaluated.
    - `mu`: Location (mean) parameter(s).
    - `sigma`: Scale (standard deviation) parameter(s); should be positive.

    # Returns
    The PDF at `x`, with the broadcasted shape of `x`, `mu`, and `sigma`.
    """
    # Standardized distance from the mean.
    z = (x - mu) / sigma

    kernel = _lax.exp(-0.5 * _lax.square(z))
    normalizer = sigma * _lax.sqrt(2.0 * _PI)

    return kernel / normalizer
30
+
31
+
32
def logprob(
    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
) -> Real[Array, "..."]:
    """
    Evaluate the log of the Normal probability density function (log PDF).

    # Parameters
    - `x`: Point(s) at which the log density is evaluated.
    - `mu`: Location (mean) parameter(s).
    - `sigma`: Scale (standard deviation) parameter(s); should be positive.

    # Returns
    The log PDF at `x`, with the broadcasted shape of `x`, `mu`, and `sigma`.
    """
    # Standardized distance from the mean.
    z = (x - mu) / sigma

    log_normalizer = _lax.log(sigma * _lax.sqrt(2.0 * _PI))

    return -log_normalizer - 0.5 * _lax.square(z)
48
+
49
+
50
def uprob(
    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
) -> Real[Array, "..."]:
    """
    Evaluate the Normal density up to the constant factor 1 / sqrt(2 * pi) (uPDF).

    The `1 / sigma` factor is kept because it depends on a parameter.

    # Parameters
    - `x`: Point(s) at which the unnormalized density is evaluated.
    - `mu`: Location (mean) parameter(s).
    - `sigma`: Scale (standard deviation) parameter(s); should be positive.

    # Returns
    The uPDF at `x`, with the broadcasted shape of `x`, `mu`, and `sigma`.
    """
    standardized = (x - mu) / sigma
    kernel = _lax.exp(-0.5 * _lax.square(standardized))

    return kernel / sigma
66
+
67
+
68
def ulogprob(
    x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
) -> Real[Array, "..."]:
    """
    Evaluate the log of the Normal density up to the additive constant -log(sqrt(2 * pi)).

    The `-log(sigma)` term is kept because it depends on a parameter.

    # Parameters
    - `x`: Point(s) at which the unnormalized log density is evaluated.
    - `mu`: Location (mean) parameter(s).
    - `sigma`: Scale (standard deviation) parameter(s); should be positive.

    # Returns
    The log uPDF at `x`, with the broadcasted shape of `x`, `mu`, and `sigma`.
    """
    standardized = (x - mu) / sigma
    quadratic = 0.5 * _lax.square(standardized)

    return -_lax.log(sigma) - quadratic
@@ -0,0 +1,5 @@
1
+ from bayinx.machinery.variational.meanfield import MeanField as MeanField
2
+ from bayinx.machinery.variational.normalizing_flow import (
3
+ NormalizingFlow as NormalizingFlow,
4
+ )
5
+ from bayinx.machinery.variational.standard import Standard as Standard
@@ -0,0 +1,3 @@
1
+ from bayinx.machinery.variational.flows.affine import Affine as Affine
2
+ from bayinx.machinery.variational.flows.planar import Planar as Planar
3
+ from bayinx.machinery.variational.flows.radial import Radial as Radial
@@ -0,0 +1,68 @@
1
+ from functools import partial
2
+ from typing import Callable, Dict, Tuple
3
+
4
+ import equinox as eqx
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from jaxtyping import Array, Float, Scalar
8
+
9
+ from bayinx.core import Flow
10
+
11
+
12
class Affine(Flow):
    """
    An affine flow `draws -> draws @ scale + shift` with a lower-triangular `scale`.

    # Attributes
    - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
    - `constraints`: A dictionary of constraining transformations.
    """

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
        eqx.field(static=True)
    )

    def __init__(self, dim: int):
        """
        Initializes an affine flow at the identity map.

        # Parameters
        - `dim`: The dimension of the parameter space.
        """
        self.params = {
            "shift": jnp.zeros(dim),
            "scale": jnp.zeros((dim, dim)),
        }

        # Map the unconstrained matrix to a lower-triangular matrix with a positive
        # diagonal. Only the diagonal goes through exp(), which keeps the map
        # invertible. The strictly-lower entries stay unconstrained so they can take
        # either sign; exponentiating them would force every off-diagonal entry to be
        # positive. Entries above the diagonal are ignored. A zero matrix maps to the
        # identity.
        self.constraints = {
            "scale": lambda m: jnp.tril(m, k=-1) + jnp.diag(jnp.exp(jnp.diag(m)))
        }

    @eqx.filter_jit
    def forward(self, draws: Array) -> Array:
        """
        Apply the affine map `draws @ scale + shift`.
        """
        params = self.constrain_pars()

        # Extract parameters
        shift: Array = params["shift"]
        scale: Array = params["scale"]

        # Compute forward transformation
        draws = draws @ scale + shift

        return draws

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
        """
        Apply the affine map to each draw and compute its log-absolute-Jacobian-determinant.

        # Returns
        A tuple `(ladj, draws)` where `ladj` is the per-draw log-absolute-determinant and
        `draws` are the transformed draws (mapped over the leading axis).
        """
        params = self.constrain_pars()

        # Extract parameters
        shift: Array = params["shift"]
        scale: Array = params["scale"]

        # Compute forward transformation
        draws = draws @ scale + shift

        # `scale` is triangular, so its determinant is the product of its diagonal;
        # the constraint keeps that diagonal positive, so the log is well-defined.
        ladj: Scalar = jnp.log(jnp.diag(scale)).sum()

        return ladj, draws