bayinx 0.3.19__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bayinx might be problematic.
- bayinx-0.4.0/PKG-INFO +47 -0
- bayinx-0.4.0/README.md +33 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/pyproject.toml +4 -2
- bayinx-0.4.0/src/bayinx/__init__.py +3 -0
- bayinx-0.3.19/src/bayinx/constraints/lower.py → bayinx-0.4.0/src/bayinx/constraints.py +9 -13
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/__init__.py +2 -2
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_flow.py +0 -3
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_model.py +41 -12
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_optimization.py +3 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_parameter.py +7 -6
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_variational.py +7 -10
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/posnormal/r.py +6 -1
- bayinx-0.4.0/src/bayinx/dists/negbinom3.py +113 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/posnormal.py +38 -1
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/uniform.py +6 -2
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -2
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/meanfield.py +3 -6
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/normalizing_flow.py +4 -5
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/standard.py +2 -1
- bayinx-0.4.0/tests/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/tests/test_variational.py +8 -12
- {bayinx-0.3.19 → bayinx-0.4.0}/uv.lock +266 -266
- bayinx-0.3.19/PKG-INFO +0 -39
- bayinx-0.3.19/README.md +0 -27
- bayinx-0.3.19/src/bayinx/__init__.py +0 -3
- bayinx-0.3.19/src/bayinx/constraints/__init__.py +0 -3
- bayinx-0.3.19/tests/test_predictive.py +0 -45
- {bayinx-0.3.19 → bayinx-0.4.0}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/.gitignore +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/LICENSE +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_constraint.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/gamma2/r.py +0 -0
- bayinx-0.3.19/src/bayinx/mhx/opt/__init__.py → bayinx-0.4.0/src/bayinx/dists/censored/negbinom3/r.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/normal.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.3.19/tests → bayinx-0.4.0/src/bayinx/mhx/opt}/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/planar.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/radial.py +1 -1
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/py.typed +0 -0
bayinx-0.4.0/PKG-INFO ADDED
@@ -0,0 +1,47 @@
+Metadata-Version: 2.4
+Name: bayinx
+Version: 0.4.0
+Summary: Bayesian Inference with JAX
+Author: Todd McCready
+Maintainer: Todd McCready
+License-File: LICENSE
+Requires-Python: >=3.12
+Requires-Dist: equinox>=0.11.12
+Requires-Dist: jax>=0.4.38
+Requires-Dist: jaxtyping>=0.2.36
+Requires-Dist: optax>=0.2.4
+Description-Content-Type: text/markdown
+
+# `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+
+The original aim of this project was to build a PPL in Python that is similar in feel to `Stan` or `Nimble` (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
+
+Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparatively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to `Stan` and parse it to valid Rust code. Additionally, the current state of `bayinx` is relatively functional (plus/minus a few things to clean up and documentation) and it offers enough functionality for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite `disize` in Python with JAX, and `bayinx` makes it easy to handle constraining transformations, filtering for parameters for gradient calculations, etc.
+
+Instead, this project is narrowing in on implementing much of `Stan`'s functionality (restricted to continuously parameterized models, point estimation + VI + MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, people will work with `target` directly and return the density like below:
+
+```py
+class NormalDist(Model):
+    x: Parameter[Array] = define(shape = (2,))
+
+    def eval(self, data: Dict[str, Array]):
+        # Constrain parameters
+        self, target = self.constrain_params() # this does nothing for the current model
+
+        # Evaluate x ~ Normal(10.0, 1.0)
+        target += normal.logprob(self.x(), 10.0, 1.0).sum()
+
+        return target
+```
+
+I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for `disize`.
+
+
+# TODO
+- For optimization and variational methods, offer a way for users to have custom stopping conditions (perhaps stop if a single parameter has converged, etc.).
+- Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
+- Low-rank affine flow?
+- https://arxiv.org/pdf/1803.05649 implement Sylvester flows.
+- Learn how to generate documentation.
+- Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
+- Look into adaptively tuning ADAM hyperparameters for VI.
bayinx-0.4.0/README.md ADDED
@@ -0,0 +1,33 @@
+# `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+
+The original aim of this project was to build a PPL in Python that is similar in feel to `Stan` or `Nimble` (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
+
+Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparatively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to `Stan` and parse it to valid Rust code. Additionally, the current state of `bayinx` is relatively functional (plus/minus a few things to clean up and documentation) and it offers enough functionality for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite `disize` in Python with JAX, and `bayinx` makes it easy to handle constraining transformations, filtering for parameters for gradient calculations, etc.
+
+Instead, this project is narrowing in on implementing much of `Stan`'s functionality (restricted to continuously parameterized models, point estimation + VI + MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, people will work with `target` directly and return the density like below:
+
+```py
+class NormalDist(Model):
+    x: Parameter[Array] = define(shape = (2,))
+
+    def eval(self, data: Dict[str, Array]):
+        # Constrain parameters
+        self, target = self.constrain_params() # this does nothing for the current model
+
+        # Evaluate x ~ Normal(10.0, 1.0)
+        target += normal.logprob(self.x(), 10.0, 1.0).sum()
+
+        return target
+```
+
+I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for `disize`.
+
+
+# TODO
+- For optimization and variational methods, offer a way for users to have custom stopping conditions (perhaps stop if a single parameter has converged, etc.).
+- Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
+- Low-rank affine flow?
+- https://arxiv.org/pdf/1803.05649 implement Sylvester flows.
+- Learn how to generate documentation.
+- Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
+- Look into adaptively tuning ADAM hyperparameters for VI.
{bayinx-0.3.19 → bayinx-0.4.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.3.19"
+version = "0.4.0"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -10,6 +10,8 @@ dependencies = [
     "jaxtyping>=0.2.36",
     "optax>=0.2.4",
 ]
+authors = [{ name = "Todd McCready" }]
+maintainers = [{ name = "Todd McCready" }]
 
 [build-system]
 requires = ["hatchling"]
@@ -19,7 +21,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
 
 [tool.bumpversion]
-current_version = "0.3.19"
+current_version = "0.4.0"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"
bayinx-0.3.19/src/bayinx/constraints/lower.py → bayinx-0.4.0/src/bayinx/constraints.py
@@ -3,9 +3,10 @@ from typing import Tuple
 import equinox as eqx
 import jax.numpy as jnp
 import jax.tree as jt
-from jaxtyping import
+from jaxtyping import Scalar, ScalarLike
 
 from bayinx.core import Constraint, Parameter
+from bayinx.core._parameter import T
 
 
 class Lower(Constraint):
@@ -18,33 +19,28 @@ class Lower(Constraint):
     def __init__(self, lb: ScalarLike):
         self.lb = jnp.array(lb)
 
-
-    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
         """
         Enforces a lower bound on the parameter and adjusts the posterior density.
 
         # Parameters
-        - `
+        - `param`: The unconstrained `Parameter`.
 
         # Parameters
         A tuple containing:
             - A modified `Parameter` with relevant leaves satisfying the constraint.
            - A scalar Array representing the log-absolute-Jacobian of the transformation.
         """
-        # Extract relevant filter specification
-        filter_spec = x.filter_spec
-
         # Extract relevant parameters(all Array)
-
+        dyn, static = eqx.partition(param, param.filter_spec)
 
         # Compute density adjustment
-        laj:
-        laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
+        laj: Scalar = jt.reduce(lambda a, b: a + b, jt.map(jnp.sum, dyn))
 
         # Compute transformation
-
+        dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)
 
         # Combine into full parameter object
-
+        param = eqx.combine(dyn, static)
 
-        return
+        return param, laj
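As a side note on the rewritten `constrain` body: the exp-shift transform it applies has a log-absolute-Jacobian equal to the unconstrained value itself, which is why the density adjustment above reduces to summing the dynamic leaves before the transform is applied. A minimal sketch of the math (not part of the diff):

```latex
% Transform applied by Lower.constrain to each unconstrained leaf value v:
y = \exp(v) + \mathrm{lb}, \qquad
\frac{\mathrm{d}y}{\mathrm{d}v} = \exp(v), \qquad
\log\left|\frac{\mathrm{d}y}{\mathrm{d}v}\right| = v
% Summing over all dynamic leaves gives the total adjustment
% \mathrm{laj} = \sum_i v_i, matching jt.reduce(add, jt.map(jnp.sum, dyn)).
```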
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/__init__.py
@@ -1,6 +1,6 @@
 from ._constraint import Constraint
 from ._flow import Flow
-from ._model import Model,
+from ._model import Model, define
 from ._optimization import optimize_model
 from ._parameter import Parameter
 from ._variational import Variational
@@ -9,7 +9,7 @@ __all__ = [
     "Constraint",
     "Flow",
     "Model",
-    "
+    "define",
     "optimize_model",
     "Parameter",
     "Variational",
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_flow.py
@@ -37,7 +37,6 @@ class Flow(eqx.Module):
 
     # Default filter specification
     @property
-    @eqx.filter_jit
     def filter_spec(self):
         """
         Generates a filter specification to subset relevant parameters for the flow.
@@ -54,7 +53,6 @@ class Flow(eqx.Module):
 
         return filter_spec
 
-    @eqx.filter_jit
     def constrain_params(self: Self):
         """
         Constrain `params` to the appropriate domain.
@@ -69,7 +67,6 @@ class Flow(eqx.Module):
 
         return t_params
 
-    @eqx.filter_jit
     def transform_params(self: Self) -> Dict[str, Array]:
         """
         Apply a custom transformation to `params` if needed.
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_model.py
@@ -1,19 +1,34 @@
 from abc import abstractmethod
 from dataclasses import field, fields
-from typing import Any, Self, Tuple
+from typing import Any, Dict, Optional, Self, Tuple
 
 import equinox as eqx
 import jax.numpy as jnp
 import jax.tree as jt
-from jaxtyping import Scalar
+from jaxtyping import PyTree, Scalar
 
 from ._constraint import Constraint
 from ._parameter import Parameter
 
 
-def
-
-
+def define(
+    shape: Optional[Tuple[int, ...]] = None,
+    init: Optional[PyTree] = None,
+    constraint: Optional[Constraint] = None
+):
+    """Define a parameter."""
+    metadata: Dict = {}
+    if constraint is not None:
+        metadata["constraint"] = constraint
+
+    if isinstance(shape, Tuple):
+        metadata["shape"] = shape
+    elif isinstance(init, PyTree):
+        metadata["init"] = init
+    else:
+        raise TypeError("Neither 'shape' nor 'init' were given as proper arguments.")
+
+    return field(metadata = metadata)
 
 
 class Model(eqx.Module):
@@ -22,16 +37,32 @@ class Model(eqx.Module):
 
     Annotate parameter attributes with `Parameter`.
 
-    Include constraints by setting them equal to `
+    Include constraints by setting them equal to `define(Constraint)`.
     """
 
+    def __new__(cls, *args, **kwargs):
+        obj = super().__new__(cls)
+
+        # Auto-initialize parameters based on `define` metadata
+        for f in fields(cls):
+            if "shape" in f.metadata:
+                # Construct jax Array with correct dimensions
+                setattr(obj, f.name, Parameter(jnp.zeros(f.metadata["shape"])))
+            elif "init" in f.metadata:
+                # Slot in given 'init' object
+                setattr(obj, f.name, Parameter(f.metadata["init"]))
+
+        return obj
+
+    def __init__(self):
+        return self
+
     @abstractmethod
     def eval(self, data: Any) -> Scalar:
         pass
 
     # Default filter specification
     @property
-    @eqx.filter_jit
     def filter_spec(self) -> Self:
         """
         Generates a filter specification to subset relevant parameters for the model.
@@ -49,12 +80,11 @@ class Model(eqx.Module):
         filter_spec = eqx.tree_at(
             lambda model: getattr(model, f.name),
             filter_spec,
-            replace=attr.filter_spec
+            replace=attr.filter_spec
         )
 
         return filter_spec
 
-    @eqx.filter_jit
     def constrain_params(self) -> Tuple[Self, Scalar]:
         """
         Constrain parameters to the appropriate domain.
@@ -70,14 +100,14 @@ class Model(eqx.Module):
             attr = getattr(self, f.name)
 
             # Check if constrained parameter
-            if isinstance(attr, Parameter) and "constraint" in f.metadata:
+            if isinstance(attr, Parameter) and ("constraint" in f.metadata):
                 param = attr
                 constraint = f.metadata["constraint"]
 
                 # Apply constraint
                 param, laj = constraint.constrain(param)
 
-                # Update parameters for constrained model
+                # Update parameters for constrained model at same node
                 constrained = eqx.tree_at(
                     lambda model: getattr(model, f.name), constrained, replace=param
                 )
@@ -87,7 +117,6 @@ class Model(eqx.Module):
 
         return constrained, target
 
-    @eqx.filter_jit
     def transform_params(self) -> Tuple[Self, Scalar]:
         """
         Apply a custom transformation to parameters if needed(defaults to constrained parameters).
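To make the new `define`/`__new__` machinery concrete, here is a hedged usage sketch, not taken from the diff: the model class, its field names, and the data key `"y"` are hypothetical, but `define(shape=...)`, `define(shape=..., constraint=...)`, and `constrain_params()` follow the API shown above and in the new README.

```py
# Illustrative sketch only; assumes Lower is importable from bayinx.constraints
# (the renamed module above) and that Parameter values are read via `self.<name>()`,
# as in the README example.
from typing import Dict

from jaxtyping import Array

from bayinx.constraints import Lower
from bayinx.core import Model, Parameter, define
from bayinx.dists import normal


class LocScale(Model):
    mu: Parameter[Array] = define(shape=(1,))                            # auto-initialized to zeros
    sigma: Parameter[Array] = define(shape=(1,), constraint=Lower(0.0))  # kept above 0.0

    def eval(self, data: Dict[str, Array]):
        # Applies Lower to `sigma`; `target` starts at the log-abs-Jacobian adjustment
        self, target = self.constrain_params()

        # Likelihood: y ~ Normal(mu, sigma)
        target += normal.logprob(data["y"], self.mu(), self.sigma()).sum()

        return target
```

Note that `__new__` fills `mu` and `sigma` with zero arrays of the declared shape before `__init__` runs, so no explicit constructor is needed for a model defined this way.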
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_optimization.py
@@ -10,6 +10,8 @@ from optax import GradientTransformation, OptState, Schedule
 from ._model import Model
 
 M = TypeVar("M", bound=Model)
+
+
 @eqx.filter_jit
 def optimize_model(
     model: M,
@@ -39,6 +41,7 @@ def optimize_model(
 
         # Evaluate posterior
         return model.eval(data)
+
     eval_grad: Callable[[M], M] = eqx.filter_jit(eqx.filter_grad(eval))
 
     # Construct scheduler
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_parameter.py
@@ -5,6 +5,8 @@ import jax.tree as jt
 from jaxtyping import PyTree
 
 T = TypeVar("T", bound=PyTree)
+
+
 class Parameter(eqx.Module, Generic[T]):
     """
     A container for a parameter of a `Model`.
@@ -26,19 +28,18 @@ class Parameter(eqx.Module, Generic[T]):
 
     # Default filter specification
     @property
-    @eqx.filter_jit
     def filter_spec(self) -> Self:
         """
-        Generates a filter specification to filter
+        Generates a filter specification to filter for dynamic parameters.
         """
         # Generate empty specification
-        filter_spec = jt.map(lambda _: False, self)
+        filter_spec: Self = jt.map(lambda _: False, self)
 
-        # Specify Array leaves
+        # Specify Array-like leaves
         filter_spec = eqx.tree_at(
-            lambda
+            lambda param: param.vals,
             filter_spec,
-            replace=jt.map(eqx.
+            replace=jt.map(eqx.is_inexact_array_like, self.vals),
         )
 
         return filter_spec
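For context, the filter spec produced above is what `eqx.partition` consumes elsewhere in the codebase; a small, hypothetical sketch of that round trip (variable names are illustrative, `.vals` is the field shown in the diff):

```py
# Hypothetical sketch: split a Parameter into its trainable (inexact array-like)
# leaves and everything else, then reassemble it unchanged.
import equinox as eqx
import jax.numpy as jnp

from bayinx.core import Parameter

p = Parameter(jnp.zeros(3))                    # Parameter wraps its value as `.vals`
dyn, static = eqx.partition(p, p.filter_spec)  # `dyn` keeps only the array leaves
p_again = eqx.combine(dyn, static)             # same structure and values as `p`
```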
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_variational.py
@@ -22,11 +22,11 @@ class Variational(eqx.Module, Generic[M]):
 
     # Attributes
     - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
-    - `
+    - `_static`: The static component of a partitioned `Model` used to initialize the `Variational` object.
     """
 
     _unflatten: Callable[[Array], M]
-
+    _static: M
 
     @abstractmethod
     def filter_spec(self):
@@ -69,7 +69,7 @@ class Variational(eqx.Module, Generic[M]):
         model: M = self._unflatten(draw)
 
         # Combine with constraints
-        model: M = eqx.combine(model, self.
+        model: M = eqx.combine(model, self._static)
 
         return model
 
@@ -89,7 +89,6 @@ class Variational(eqx.Module, Generic[M]):
         # Evaluate posterior density
         return model.eval(data)
 
-    # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
     @eqx.filter_jit
     def fit(
         self,
@@ -116,11 +115,9 @@ class Variational(eqx.Module, Generic[M]):
         dyn, static = eqx.partition(self, self.filter_spec)
 
         # Construct scheduler
-        schedule: Schedule = opx.
-            init_value=
-
-            warmup_steps=int(max_iters / 10),
-            decay_steps=max_iters - int(max_iters / 10),
+        schedule: Schedule = opx.cosine_decay_schedule(
+            init_value=learning_rate,
+            decay_steps=max_iters,
         )
 
         # Initialize optimizer
@@ -175,7 +172,7 @@ class Variational(eqx.Module, Generic[M]):
         return eqx.combine(dyn, static)
 
     @eqx.filter_jit
-    def
+    def _posterior_predictive(
         self,
         func: Callable[[M, Any], Array],
         n: int,
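The scheduler change in `fit` above swaps a warmup-plus-cosine schedule for a plain cosine decay. A small sketch of what the new construction evaluates to, with illustrative values for `learning_rate` and `max_iters` (not from the diff):

```py
# Sketch only: optax's cosine_decay_schedule decays from `init_value` at step 0
# toward 0 at `decay_steps`, with no warmup phase.
import optax as opx

learning_rate, max_iters = 1e-2, 1_000

schedule = opx.cosine_decay_schedule(
    init_value=learning_rate,
    decay_steps=max_iters,
)

print(schedule(0))          # ~1e-2 at the start
print(schedule(max_iters))  # ~0.0 at the end
```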
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/posnormal/r.py
@@ -111,6 +111,11 @@ def sample(
 
     # Construct draws
     draws = jr.uniform(key, shape)
-    draws = mu + sigma * ndtri(
+    draws = mu + sigma * ndtri(
+        normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
+    )
+
+    # Censor draws
+    draws.at[censor <= draws].set(censor)
 
     return draws
bayinx-0.4.0/src/bayinx/dists/negbinom3.py ADDED
@@ -0,0 +1,113 @@
+import jax.numpy as jnp
+from jax.scipy.special import gammaln
+from jaxtyping import Array, ArrayLike, Float, UInt
+
+__PI = 3.141592653589793
+
+
+def __binom(x, y):
+    """
+    Helper function for the Binomial coefficient.
+    """
+    return jnp.exp(gammaln(x + 1) - gammaln(y + 1) - gammaln(x - y + 1))
+
+
+def prob(
+    x: UInt[ArrayLike, "..."],
+    mu: Float[ArrayLike, "..."],
+    phi: Float[ArrayLike, "..."],
+) -> Float[Array, "..."]:
+    """
+    The probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.
+
+    # Parameters
+    - `x`: Where to evaluate the PMF.
+    - `mu`: The mean.
+    - `phi`: The inverse overdispersion.
+
+    # Returns
+    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
+    """
+    # Cast to Array
+    x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)
+
+    return jnp.exp(logprob(x, mu, phi))
+
+
+def logprob(
+    x: UInt[ArrayLike, "..."],
+    mu: Float[ArrayLike, "..."],
+    phi: Float[ArrayLike, "..."],
+) -> Float[Array, "..."]:
+    """
+    The log-transformed probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.
+
+    # Parameters
+    - `x`: Where to evaluate the log PMF.
+    - `mu`: The mean.
+    - `phi`: The inverse overdispersion.
+
+    # Returns
+    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
+    """
+    # Cast to Array
+    x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)
+
+    # Evaluate log PMF
+    evals: Array = jnp.where(
+        x < 0,
+        -jnp.inf,
+        (
+            gammaln(x + phi)
+            - gammaln(x + 1)
+            - gammaln(phi)
+            + x * (jnp.log(mu) - jnp.log(mu + phi))
+            + phi * (jnp.log(phi) - jnp.log(mu + phi))
+        ),
+    )
+
+    return evals
+
+
+def cdf(
+    x: Float[ArrayLike, "..."],
+    mu: Float[ArrayLike, "..."],
+    sigma: Float[ArrayLike, "..."],
+) -> Float[Array, "..."]:
+    # Cast to Array
+    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
+
+    return jnp.array(1.0)
+
+
+def logcdf(
+    x: Float[ArrayLike, "..."],
+    mu: Float[ArrayLike, "..."],
+    sigma: Float[ArrayLike, "..."],
+) -> Float[Array, "..."]:
+    # Cast to Array
+    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
+
+    return jnp.array(1.0)
+
+
+def ccdf(
+    x: Float[ArrayLike, "..."],
+    mu: Float[ArrayLike, "..."],
+    sigma: Float[ArrayLike, "..."],
+) -> Float[Array, "..."]:
+    # Cast to Array
+    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
+
+    return jnp.array(1.0)
+
+
+def logccdf(
+    x: Float[ArrayLike, "..."],
+    mu: Float[ArrayLike, "..."],
+    sigma: Float[ArrayLike, "..."],
+) -> Float[Array, "..."]:
+    # Cast to Array
+    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
+
+    return jnp.array(1.0)
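For reference, the log PMF implemented in `logprob` above is the mean / inverse-overdispersion ("NB2") parameterization; the `cdf`, `logcdf`, `ccdf`, and `logccdf` functions are currently placeholders that return 1.0. Written out:

```latex
% Mean / inverse-overdispersion parameterization used by negbinom3.logprob:
\log P(X = x \mid \mu, \phi)
  = \log\Gamma(x + \phi) - \log\Gamma(x + 1) - \log\Gamma(\phi)
  + x\,\bigl[\log\mu - \log(\mu + \phi)\bigr]
  + \phi\,\bigl[\log\phi - \log(\mu + \phi)\bigr],
\qquad
\mathbb{E}[X] = \mu, \quad \operatorname{Var}[X] = \mu + \mu^{2}/\phi .
```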
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/posnormal.py
@@ -1,5 +1,7 @@
 import jax.numpy as jnp
-
+import jax.random as jr
+from jax.scipy.special import ndtri
+from jaxtyping import Array, ArrayLike, Float, Key
 
 from bayinx.dists import normal
 
@@ -251,3 +253,38 @@ def logccdf(
     evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
 
     return evals
+
+
+def sample(
+    n: int,
+    mu: Float[ArrayLike, "..."],
+    sigma: Float[ArrayLike, "..."],
+    key: Key = jr.PRNGKey(0),
+) -> Float[Array, "..."]:
+    """
+    Sample from a positive Normal distribution.
+
+    # Parameters
+    - `n`: Number of draws to sample per-parameter.
+    - `mu`: The mean.
+    - `sigma`: The standard deviation.
+
+    # Returns
+    Draws from a positive Normal distribution. The output will have the shape of (n,) + the broadcasted shapes of `mu` and `sigma`.
+    """
+    # Cast to Array
+    mu, sigma = (
+        jnp.asarray(mu),
+        jnp.asarray(sigma),
+    )
+
+    # Derive shape
+    shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape)
+
+    # Construct draws
+    draws = jr.uniform(key, shape)
+    draws = mu + sigma * ndtri(
+        normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
+    )
+
+    return draws
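The new `sample` uses inverse-CDF sampling for a Normal distribution truncated to the positive half-line, which is what the `ndtri`/`normal.cdf` expression above encodes. A sketch of the derivation:

```latex
% Inverse-CDF sampling for X ~ Normal(mu, sigma^2) truncated to [0, \infty).
% With U ~ Uniform(0, 1) and \Phi the standard normal CDF:
F(x) = \frac{\Phi\!\left(\frac{x-\mu}{\sigma}\right) - \Phi\!\left(-\frac{\mu}{\sigma}\right)}
            {\Phi\!\left(\frac{\mu}{\sigma}\right)},
\qquad
X = \mu + \sigma\,\Phi^{-1}\!\left(\Phi\!\left(-\tfrac{\mu}{\sigma}\right)
      + U\,\Phi\!\left(\tfrac{\mu}{\sigma}\right)\right).
```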
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/uniform.py
@@ -83,8 +83,12 @@ def ulogprob(
 
     return jnp.zeros(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
 
+
 def sample(
-    n: int,
+    n: int,
+    lb: Float[ArrayLike, "..."],
+    ub: Float[ArrayLike, "..."],
+    key: Key = jr.PRNGKey(0),
 ) -> Float[Array, "..."]:
     """
     Sample from a Uniform distribution.
@@ -104,6 +108,6 @@ def sample(
     shape = (n,) + jnp.broadcast_shapes(lb.shape, ub.shape)
 
     # Construct draws
-    draws = jr.uniform(key, shape, minval
+    draws = jr.uniform(key, shape, minval=lb, maxval=ub)
 
     return draws
{bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/meanfield.py
@@ -34,7 +34,7 @@ class MeanField(Variational, Generic[M]):
         - `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
         """
         # Partition model
-        params, self.
+        params, self._static = eqx.partition(model, model.filter_spec)
 
         # Flatten params component
         params, self._unflatten = ravel_pytree(params)
@@ -44,7 +44,6 @@ class MeanField(Variational, Generic[M]):
         self.log_std = jnp.full(params.size, init_log_std, params.dtype)
 
     @property
-    @eqx.filter_jit
     def filter_spec(self):
         # Generate empty specification
         filter_spec = jtu.tree_map(lambda _: False, self)
@@ -67,8 +66,7 @@ class MeanField(Variational, Generic[M]):
     def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
         # Sample variational draws
         draws: Array = (
-            jr.normal(key=key, shape=(n, self.mean.size))
-            * jnp.exp(self.log_std)
+            jr.normal(key=key, shape=(n, self.mean.size)) * jnp.exp(self.log_std)
             + self.mean
         )
 
@@ -108,10 +106,9 @@ class MeanField(Variational, Generic[M]):
     def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
         dyn, static = eqx.partition(self, self.filter_spec)
 
-        @eqx.filter_grad
         @eqx.filter_jit
+        @eqx.filter_grad
        def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
-            # Combine
             vari = eqx.combine(dyn, static)
 
             # Sample draws from variational distribution