bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff compares publicly available package versions as released to their public registry. It is provided for informational purposes only.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/core/node.py
ADDED
@@ -0,0 +1,201 @@
from typing import Any, Generic, Iterator, Self

import equinox as eqx
import jax.tree as jt
from jaxtyping import PyTree

from bayinx.core.types import T
from bayinx.core.utils import _extract_obj, _merge_filter_specs


class Node(eqx.Module, Generic[T]):
    """
    A thin wrapper for nodes of a probabilistic model.

    # Attributes
    - `obj`: An internal realization of the node.
    - `_filter_spec`: An internal filter specification for `obj`.
    """

    obj: T
    _filter_spec: PyTree

    def __init__(self, obj: T, filter_spec: PyTree):
        self.obj = obj
        self._filter_spec = filter_spec

    @property
    def filter_spec(self) -> Self:
        """
        An outer filter specification for the full node.
        """
        # Generate empty specification
        node_filter_spec: Self = jt.map(lambda _: False, self)

        # Filter based on inner filter specification for 'obj'
        node_filter_spec = eqx.tree_at(
            lambda node: node.obj,
            node_filter_spec,
            replace=self._filter_spec,
        )

        return node_filter_spec

    # Wrappers around internal dunder methods ----
    def __getitem__(self, key: Any) -> "Node":
        if isinstance(key, Node):
            raise TypeError("Subsetting nodes with nodes is not yet supported.")

        # Subset internally
        new_obj = self.obj[key]
        if type(self.obj) is type(self._filter_spec):
            new_filter_spec = self._filter_spec[key]
        else:
            new_filter_spec = self._filter_spec

        # Create new subsetted node
        return type(self)(new_obj, new_filter_spec)

    def __iter__(self) -> Iterator["Node"]:
        for obj_i, spec_i in zip(self.obj, self._filter_spec):
            # Create a new Node for the current element
            yield Node(obj_i, spec_i)

    ## Arithmetic ----
    def __add__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform addition
        new_obj = lhs_obj + rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __sub__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform subtraction
        new_obj = lhs_obj - rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __mul__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform multiplication
        new_obj = lhs_obj * rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __matmul__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform matrix multiplication
        new_obj = lhs_obj @ rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __truediv__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform true division
        new_obj = lhs_obj / rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __floordiv__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform floor division
        new_obj = lhs_obj // rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __pow__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform exponentiation
        new_obj = lhs_obj ** rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)

    def __mod__(self, other: Any) -> "Node":
        # Extract internal objects and their filter specifications
        lhs_obj, lhs_filter_spec = _extract_obj(self)
        rhs_obj, rhs_filter_spec = _extract_obj(other)

        # Perform modulus
        new_obj = lhs_obj % rhs_obj

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_filter_spec, rhs_filter_spec],
            [lhs_obj, rhs_obj],
            new_obj
        )

        return Node(new_obj, new_filter_spec)
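The arithmetic wrappers above all follow the same pattern: unwrap both operands, apply the operation, then re-merge the filter specifications. A minimal usage sketch (hypothetical, not from the package's docs), assuming array-valued nodes and boolean filter specifications:

from jax import numpy as jnp

from bayinx.core.node import Node

# Wrap an array with an all-True filter spec (everything trainable).
a = Node(jnp.array([1.0, 2.0, 3.0]), True)

# Arithmetic against a plain value: the scalar gets an implicit True spec,
# and the result's spec is merged from the operands sharing the result's type.
b = a + 1.0

# Arithmetic between two nodes follows the same merge rule.
c = a * b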
bayinx/core/types.py
ADDED
@@ -0,0 +1,17 @@
from typing import Generic, Protocol, TypeVar, runtime_checkable

from jaxtyping import PyTree

from bayinx.core.constraint import Constraint

T = TypeVar("T", bound=PyTree)

@runtime_checkable
class HasConstraint(Protocol, Generic[T]):
    """
    Protocol for probabilistic nodes that have constraints.
    """

    obj: T
    _filter_spec: PyTree
    _constraint: Constraint
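Because `HasConstraint` is decorated with `@runtime_checkable`, satisfying it is purely structural: any object exposing the three attributes matches. A hypothetical sketch (`PlainNode` is made up for illustration):

from bayinx.core.types import HasConstraint

class PlainNode:
    # Hypothetical stand-in: exposes the three attributes the protocol names.
    def __init__(self, obj, filter_spec, constraint):
        self.obj = obj
        self._filter_spec = filter_spec
        self._constraint = constraint

node = PlainNode(obj=1.0, filter_spec=True, constraint=None)
print(isinstance(node, HasConstraint))  # True: checks attribute presence only, not types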
bayinx/core/utils.py
ADDED
@@ -0,0 +1,109 @@
from typing import Any, Dict, List, Optional, Tuple

import jax.tree as jt
from jaxtyping import PyTree


def _extract_shape_params(shape_spec: int | str | Tuple[int | str, ...]) -> set[str]:
    """
    Extract parameter names from a shape specification.
    """
    params = set()
    if isinstance(shape_spec, str):
        params.add(shape_spec)
    elif isinstance(shape_spec, tuple):
        for item in shape_spec:
            if isinstance(item, str):
                params.add(item)

    return params

def _resolve_shape_spec(
    shape_spec: None | int | str | Tuple[int | str, ...],
    shape_values: Dict[str, int | Tuple[int, ...]]
) -> None | Tuple[int, ...]:
    """
    Replaces named dimensions in a shape specification with their integer or tuple values.

    # Example
    For `shape_values = {'k': 5, 's': (3, 2, 1)}`:
    `('k', 4, 's')` --> `(5, 4, 3, 2, 1)`
    """
    if shape_spec is None:
        return None

    # Coerce to tuple for uniform processing
    if isinstance(shape_spec, (int, str)):
        shape_spec = (shape_spec,)

    resolved_spec: List[int] = []
    for dim in shape_spec:
        if isinstance(dim, str):
            if dim in shape_values:
                resolved_value = shape_values[dim]  # Grab initialized value

                if isinstance(resolved_value, int):
                    # Scalar shape dimension (e.g., 'k' -> 3)
                    resolved_spec.append(resolved_value)
                elif isinstance(resolved_value, tuple):
                    # Packed shape dimension (e.g., 'shape' -> (3, 2, 1))
                    resolved_spec.extend(resolved_value)
                else:
                    raise TypeError(f"Shape parameter '{dim}' resolved to an unsupported type: {type(resolved_value).__name__}")
            else:  # dim not in shape_values
                raise TypeError(f"Shape parameter '{dim}' was not initialized with a value.")
        elif isinstance(dim, int):
            # Literal integer (e.g., 3 -> 3)
            resolved_spec.append(dim)
        else:
            raise TypeError(f"Shape parameter {dim} was incorrectly specified (must be 'int' or 'str', got '{type(dim).__name__}').")

    return tuple(resolved_spec)

def _extract_obj(x: Any) -> Tuple[Any, Any]:
    """
    Extract the object and its (potentially implicit) filter specification.
    """
    from bayinx.core.node import Node

    if isinstance(x, Node):
        obj = x.obj
        filter_spec = x._filter_spec
    else:
        obj: Any = x  # type: ignore
        filter_spec = True  # implicit filter specification

    return (obj, filter_spec)


def _merge_filter_specs(
    filter_specs: List[PyTree],
    objs: Optional[List[PyTree]] = None,
    obj: Optional[PyTree] = None
) -> PyTree:
    """
    Merge filter specifications that share the type of `obj`.

    If `obj` and `objs` are not provided, then `filter_specs` are merged as-is.
    """

    def _merge(*args):
        return all(args)

    if objs is None and obj is None:
        filter_spec: Any = jt.map(_merge, *filter_specs)  # type: ignore
    else:
        selected_specs: List[PyTree] = []

        # Include filter specs whose objects share the correct type
        for cur_obj, cur_spec in zip(objs, filter_specs):  # type: ignore
            if type(obj) is type(cur_obj):
                selected_specs.append(cur_spec)

        if len(selected_specs) != 0:
            filter_spec = jt.map(_merge, *selected_specs)
        else:
            filter_spec = True

    return filter_spec
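Given the implementation above, the resolver expands named dimensions to their bound values and splices packed (tuple-valued) dimensions in place; unbound or mistyped entries raise `TypeError`. For example (internal helper, shown for illustration):

from bayinx.core.utils import _resolve_shape_spec

_resolve_shape_spec(("k", 4, "s"), {"k": 5, "s": (3, 2, 1)})  # (5, 4, 3, 2, 1)
_resolve_shape_spec("k", {"k": 5})   # (5,) -- bare specs are coerced to 1-tuples
_resolve_shape_spec(None, {})        # None
_resolve_shape_spec(("m",), {"k": 5})  # TypeError: 'm' was not initialized with a value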
bayinx/core/variational.py
ADDED
@@ -0,0 +1,170 @@
from abc import abstractmethod
from functools import partial
from typing import Callable, Generic, Self, Tuple, TypeVar

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import optax as opx
from jaxtyping import Array, PRNGKeyArray, PyTree, Scalar
from optax import GradientTransformation, OptState

from bayinx.core.model import Model

M = TypeVar("M", bound=Model)

class Variational(eqx.Module, Generic[M]):
    """
    An abstract base class used to define variational methods.

    # Attributes
    - `dim`: The dimension of the support.
    - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
    - `_static`: The static component of a partitioned `Model` used to initialize the `Variational` object.
    """
    dim: int
    _unflatten: Callable[[Array], M]
    _static: M

    @property
    @abstractmethod
    def filter_spec(self) -> Self:
        """
        Filter specification for dynamic and static components of the
        `Variational` object.
        """
        pass

    @abstractmethod
    def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
        """
        Sample from the variational distribution.
        """
        pass

    @abstractmethod
    def eval(self, draws: Array) -> Array:
        """
        Evaluate the variational distribution at `draws`.
        """
        pass

    @abstractmethod
    def elbo(self, n: int, batch_size: int, key: PRNGKeyArray) -> Array:
        """
        Evaluate the ELBO.
        """
        pass

    @abstractmethod
    def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> M:
        """
        Evaluate the gradient of the ELBO.
        """
        pass

    @abstractmethod
    def elbo_and_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Tuple[Scalar, PyTree]:
        """
        Evaluate the ELBO and its gradient.
        """
        pass

    @eqx.filter_jit
    def reconstruct_model(self, draw: Array) -> M:
        # Unflatten variational draw
        model: M = self._unflatten(draw)

        # Combine with static components
        model: M = eqx.combine(model, self._static)

        return model

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def eval_model(self, draws: Array) -> Array:
        """
        Reconstruct models from variational draws and evaluate their posterior.

        # Parameters
        - `draws`: A set of variational draws.
        """
        # Unflatten variational draw
        model: M = self.reconstruct_model(draws)

        # Evaluate posterior
        return model()

    @eqx.filter_jit
    def fit(
        self,
        max_iters: int,
        learning_rate: float,
        tolerance: float,
        grad_draws: int,
        batch_size: int,
        key: PRNGKeyArray = jr.key(0),
    ) -> Self:
        """
        Optimize the variational distribution.
        """
        # Partition variational
        dyn, static = eqx.partition(self, self.filter_spec)

        # Construct scheduler
        schedule = opx.warmup_cosine_decay_schedule(
            1e-8,
            learning_rate,
            int(max_iters / 10),
            max_iters - int(max_iters / 10)
        )

        # Initialize optimizer
        optim: GradientTransformation = opx.chain(
            opx.scale(-1.0), opx.adam(schedule)  # flip the sign so the ELBO is maximized
        )
        opt_state: OptState = optim.init(dyn)

        def condition(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]):
            # Unpack iteration state
            dyn, opt_state, i, key = state

            return i < max_iters

        def body(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]):
            # Unpack iteration state
            dyn, opt_state, i, key = state

            # Update iteration
            i = i + 1

            # Update PRNG key
            key, _ = jr.split(key)

            # Reconstruct variational
            vari: Self = eqx.combine(dyn, static)

            # Compute gradient of the ELBO for update
            update: M = vari.elbo_grad(grad_draws, batch_size, key)

            # Transform update through optimizer
            update, opt_state = optim.update(  # type: ignore
                update, opt_state, eqx.filter(dyn, dyn.filter_spec)  # type: ignore
            )

            # Update variational distribution
            dyn: Self = eqx.apply_updates(dyn, update)

            return dyn, opt_state, i, key

        # Run optimization loop
        dyn = lax.while_loop(
            cond_fun=condition,
            body_fun=body,
            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
        )[0]

        # Return optimized variational
        return eqx.combine(dyn, static)
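`fit` leans on equinox's partition/combine round trip: only the dynamic leaves flow through the optax optimizer, and the static remainder is reattached afterwards. A minimal sketch of that mechanic, independent of bayinx (the `Toy` module is made up for illustration):

import equinox as eqx
import jax.numpy as jnp
import jax.tree as jt

class Toy(eqx.Module):
    weight: jnp.ndarray
    name: str  # non-array leaf, treated as static

toy = Toy(jnp.ones(3), "toy")

# Split along a filter spec: `dyn` keeps the inexact-array leaves, `static` the rest.
dyn, static = eqx.partition(toy, eqx.is_inexact_array)

# Apply a dummy update to the dynamic half only.
updates = jt.map(lambda w: -0.1 * w, dyn)
dyn = eqx.apply_updates(dyn, updates)

# Reassemble the full module, exactly as `fit` does when it returns.
toy = eqx.combine(dyn, static)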
bayinx/dists/__init__.py
CHANGED
@@ -1,3 +1,5 @@
-from
-
-
+from .bernoulli import Bernoulli as Bernoulli
+from .binomial import Binomial as Binomial
+from .exponential import Exponential as Exponential
+from .normal import Normal as Normal
+from .poisson import Poisson as Poisson
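With this rewrite the distributions are re-exported at the subpackage level, so downstream code can import them directly:

from bayinx.dists import Bernoulli, Binomial, Exponential, Normal, Poisson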