bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/nodes/continuous.py
ADDED
@@ -0,0 +1,64 @@
+from typing import Any, Optional
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.tree as jt
+import numpy as np
+from jaxtyping import Array, PyTree
+
+from bayinx.constraints import Identity
+from bayinx.core.constraint import Constraint
+from bayinx.core.types import T
+from bayinx.nodes.stochastic import Stochastic
+
+
+def is_float_like(element: Any) -> bool:
+    """
+    Check if `element` is float-like.
+
+    The structure of this function is borrowed from the `Equinox` library.
+    """
+    if hasattr(element, "__jax_array__"):
+        element = element.__jax_array__()
+    if isinstance(element, (np.ndarray, np.generic)):
+        return bool(np.issubdtype(element.dtype, np.floating))
+    elif isinstance(element, Array):
+        return jnp.issubdtype(element.dtype, jnp.floating)
+    else:
+        return isinstance(element, float)
+
+
+class Continuous(Stochastic[T]):
+    """
+    A container for continuous stochastic nodes of a probabilistic model.
+
+
+    # Attributes
+    - `obj`: An internal realization of the node.
+    - `_filter_spec`: An internal filter specification for `obj`.
+    - `_constraint`: A constraining transformation.
+    """
+
+    _constraint: Constraint
+
+
+    def __init__(
+        self,
+        obj: T,
+        constraint: Constraint = Identity(),
+        filter_spec: Optional[PyTree] = None
+    ):
+        if filter_spec is None:  # Default filter specification
+            # Generate empty specification
+            filter_spec = jt.map(lambda _: False, obj)
+
+            # Specify float-like leaves
+            filter_spec = eqx.tree_at(
+                where=lambda obj: obj,
+                pytree=filter_spec,
+                replace=jt.map(is_float_like, obj),
+            )
+
+        self.obj = obj
+        self._filter_spec = filter_spec
+        self._constraint = constraint
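The default filter specification above reduces to a small pytree recipe: build an all-False skeleton, then flip exactly the float-like leaves to True. A minimal, self-contained sketch of that construction (plain JAX and Equinox, independent of bayinx, mirroring the `tree_at` call above):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.tree as jt

def is_float_like(x) -> bool:
    # Simplified stand-in for the is_float_like defined above
    return hasattr(x, "dtype") and bool(jnp.issubdtype(x.dtype, jnp.floating))

obj = {"loc": jnp.zeros(3), "count": jnp.array(4)}
filter_spec = jt.map(lambda _: False, obj)   # all-False skeleton
filter_spec = eqx.tree_at(
    where=lambda t: t,                       # replace the whole tree...
    pytree=filter_spec,
    replace=jt.map(is_float_like, obj),      # ...with per-leaf float flags
)
print(filter_spec)  # {'loc': True, 'count': False}
```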
bayinx/nodes/observed.py
ADDED
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import equinox as eqx
+import jax.tree as jt
+from jaxtyping import PyTree
+
+from bayinx.core.node import Node
+from bayinx.core.types import T
+
+
+class Observed(Node[T]):
+    """
+    A container for observed nodes of a probabilistic model.
+
+
+    # Attributes
+    - `obj`: An internal realization of the node.
+    - `_filter_spec`: An internal filter specification for `obj`.
+    """
+
+    def __init__(
+        self, obj: T, filter_spec: Optional[PyTree] = None
+    ):
+        if filter_spec is None:  # Default filter specification
+            # Generate empty specification
+            filter_spec = jt.map(lambda _: False, obj)
+
+            # Specify array-like leaves
+            filter_spec = eqx.tree_at(
+                where=lambda obj: obj,
+                pytree=filter_spec,
+                replace=jt.map(eqx.is_array_like, obj),
+            )
+
+        self.obj = obj
+        self._filter_spec = filter_spec
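Taken together with `Continuous` above, constructing nodes looks roughly as follows. This is a hypothetical sketch: the constraint modules are listed in this diff but their initializers are not shown, so the `Lower(0.0)` signature and its export from `bayinx.constraints` are assumptions.

```python
import jax.numpy as jnp

from bayinx.constraints import Lower  # assumed export, per bayinx/constraints/__init__.py
from bayinx.nodes.continuous import Continuous
from bayinx.nodes.observed import Observed

y = Observed(jnp.array([1.2, 0.7, 3.1]))        # data; array-like leaves marked in the filter spec
sigma = Continuous(jnp.array(0.0), Lower(0.0))  # latent scale, constrained positive (assumed signature)
```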
bayinx/nodes/stochastic.py
ADDED
@@ -0,0 +1,25 @@
+from abc import abstractmethod
+from typing import Optional
+
+from bayinx.core.node import Node
+from bayinx.core.types import T
+
+
+class Stochastic(Node[T]):
+    """
+    A container for stochastic (unobserved) nodes of a probabilistic model.
+
+    Subclasses can be constructed with defined filter specifications (implement the `filter_spec` property).
+
+    # Attributes
+    - `obj`: An internal realization of the node.
+    - `_filter_spec`: An internal filter specification for `obj`.
+    """
+
+    @abstractmethod
+    def __init__(
+        self,
+        obj: T,
+        filter_spec: Optional[T],
+    ):
+        pass
bayinx/ops.py
ADDED
@@ -0,0 +1,104 @@
+import jax.numpy as jnp
+import jax.tree as jt
+from jaxtyping import Array, ArrayLike, Real
+
+from bayinx.core.node import Node
+from bayinx.core.utils import _extract_obj
+
+
+def exp(node: Node) -> Node:
+    """
+    Apply the exponential transformation (jnp.exp) to a node.
+    """
+    obj, filter_spec = _extract_obj(node)
+
+    # Helper function for the single-leaf exponential transform
+    def leaf_exp(x: Real[ArrayLike, "..."]) -> Array:
+        return jnp.exp(x)
+
+    # Apply exponential
+    new_obj = jt.map(leaf_exp, obj)
+
+    return type(node)(new_obj, filter_spec)
+
+
+def log(node: Node) -> Node:
+    """
+    Apply the natural logarithm transformation (jnp.log) to a node.
+    Handles input value restrictions (must be positive).
+    """
+    obj, filter_spec = _extract_obj(node)
+
+    # Helper function for the single-leaf log transform
+    def leaf_log(x: Real[ArrayLike, "..."]) -> Array:
+        return jnp.log(x)
+
+    # Apply logarithm
+    new_obj = jt.map(leaf_log, obj)
+
+    return type(node)(new_obj, filter_spec)
+
+
+def sin(node: Node) -> Node:
+    """
+    Apply the sine transformation (jnp.sin) to a node.
+    """
+    obj, filter_spec = _extract_obj(node)
+
+    # Helper function for the single-leaf sine transform
+    def leaf_sin(x: Real[ArrayLike, "..."]) -> Array:
+        return jnp.sin(x)
+
+    # Apply sine
+    new_obj = jt.map(leaf_sin, obj)
+
+    return type(node)(new_obj, filter_spec)
+
+
+def cos(node: Node) -> Node:
+    """
+    Apply the cosine transformation (jnp.cos) to a node.
+    """
+    obj, filter_spec = _extract_obj(node)
+
+    # Helper function for the single-leaf cosine transform
+    def leaf_cos(x: Real[ArrayLike, "..."]) -> Array:
+        return jnp.cos(x)
+
+    # Apply cosine
+    new_obj = jt.map(leaf_cos, obj)
+
+    return type(node)(new_obj, filter_spec)
+
+
+def tanh(node: Node) -> Node:
+    """
+    Apply the hyperbolic tangent transformation (jnp.tanh) to a node.
+    """
+    obj, filter_spec = _extract_obj(node)
+
+    # Helper function for the single-leaf tanh transform
+    def leaf_tanh(x: Real[ArrayLike, "..."]) -> Array:
+        return jnp.tanh(x)
+
+    # Apply tanh
+    new_obj = jt.map(leaf_tanh, obj)
+
+    return type(node)(new_obj, filter_spec)
+
+
+def sigmoid(node: Node) -> Node:
+    """
+    Apply the sigmoid (logistic) transformation to a node.
+    Sigmoid formula: 1 / (1 + exp(-x))
+    """
+    obj, filter_spec = _extract_obj(node)
+
+    # Helper function for the single-leaf sigmoid transform
+    def leaf_sigmoid(x: Real[ArrayLike, "..."]) -> Array:
+        return 1.0 / (1.0 + jnp.exp(-x))  # type: ignore
+
+    # Apply sigmoid
+    new_obj = jt.map(leaf_sigmoid, obj)
+
+    return type(node)(new_obj, filter_spec)
+
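Each op maps the corresponding `jnp` function over a node's leaves and rewraps the result as `type(node)(new_obj, filter_spec)`, so the wrapper type and filter specification survive the transformation. A usage sketch, assuming bayinx 0.5.3 is installed and using only the `Observed` constructor shown above:

```python
import jax.numpy as jnp

from bayinx import ops
from bayinx.nodes.observed import Observed

node = Observed(jnp.array([0.5, 1.0, 2.0]))

exp_node = ops.exp(node)       # elementwise exponential, still an Observed
log_node = ops.log(exp_node)   # round-trips back to the original values
print(log_node.obj)            # [0.5 1.  2. ]
```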
bayinx/posterior.py
ADDED
@@ -0,0 +1,220 @@
+
+from functools import partial
+from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type
+
+import equinox as eqx
+import jax
+import jax.random as jr
+from jax.lax import scan
+from jaxtyping import Array, PRNGKeyArray
+
+from bayinx.core.flow import FlowSpec
+from bayinx.core.node import Node
+from bayinx.core.variational import M
+from bayinx.vi.normalizing_flow import NormalizingFlow
+from bayinx.vi.standard import Standard
+
+
+class Posterior(Generic[M]):
+    """
+    The posterior distribution for a model.
+
+    # Attributes
+    - `vari`: The variational approximation of the posterior.
+    - `config` The configuration for the posterior.
+    """
+    vari: NormalizingFlow[M]
+    config: Dict[str, Any]
+
+    def __init__(self, model_def: Type[M], **kwargs: Any):
+        # (hopefully) omit intermediate model construction through jit
+        @eqx.filter_jit
+        def construct_base(model_def):
+            # Construct model
+            model = model_def(**kwargs)
+
+            return Standard(model)
+
+        # Construct standard normal base distribution
+        self.vari = construct_base(model_def)
+
+        # Include default attributes
+        self.config = {
+            "learning_rate": 0.1 / self.vari.dim**0.5,
+            "tolerance": 1e-4,
+            "grad_draws": 4,
+            "batch_size": 1
+        }
+
+
+    def configure(
+        self,
+        flowspecs: Optional[List[FlowSpec]] = None,
+        learning_rate: Optional[float] = None,
+        tolerance: Optional[float] = None,
+        grad_draws: Optional[int] = None,
+        batch_size: Optional[int] = None
+    ):
+        """
+        Configure the variational approximation.
+
+        # Parameters
+        - `flowspecs`: The specification for a sequence of flows.
+        """
+        # Append new NF architecture
+        if flowspecs is not None:
+            # Initialize NF architecture
+            flows = [
+                flowspec.construct(self.vari.dim) for flowspec in flowspecs
+            ]
+
+            if isinstance(self.vari, Standard):
+                # Construct new normalizing flow
+                self.vari = NormalizingFlow(
+                    base = self.vari,
+                    flows = flows,
+                    _static = self.vari._static,
+                    _unflatten = self.vari._unflatten
+                )
+            elif isinstance(self.vari, NormalizingFlow):
+                # Freeze current flows
+                for flow in self.vari.flows:
+                    object.__setattr__(flow, 'static', True)  # kind of illegal but I need to avoid copies
+
+                # Append new flows
+                self.vari.flows.extend(flows)
+
+        # Include other settings
+        if learning_rate is not None:
+            self.config["learning_rate"] = learning_rate
+        if tolerance is not None:
+            self.config["tolerance"] = tolerance
+        if grad_draws is not None:
+            self.config["grad_draws"] = grad_draws
+        if batch_size is not None:
+            self.config["batch_size"] = batch_size
+
+
+    def fit(
+        self,
+        max_iters: int = 50_000,
+        learning_rate: Optional[float] = None,
+        tolerance: Optional[float] = None,
+        grad_draws: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        key: PRNGKeyArray = jr.key(0),
+    ):
+        # Include settings
+        if learning_rate is not None:
+            self.config["learning_rate"] = learning_rate
+        if tolerance is not None:
+            self.config["tolerance"] = tolerance
+        if grad_draws is not None:
+            self.config["grad_draws"] = grad_draws
+        if batch_size is not None:
+            self.config["batch_size"] = batch_size
+
+        if isinstance(self.vari, Standard):
+            # Construct default sequence of optimization
+            print("TODO")
+        else:
+            # Optimize variational approximation with user-specified flows
+            self.vari = self.vari.fit(
+                max_iters,
+                self.config["learning_rate"],
+                self.config["tolerance"],
+                self.config["grad_draws"],
+                self.config["batch_size"],
+                key
+            )
+
+    def sample(
+        self,
+        node: str,
+        n_draws: int,
+        batch_size: Optional[int] = None,
+        key: PRNGKeyArray = jr.key(0)
+    ) -> Array:
+        """
+        Sample a node from the posterior distribution.
+
+        # Parameters
+        - `node`: The name of the node.
+        - `n_draws`: The number of draws from the posterior.
+        - `batch_size`: The number of draws for the full model ever initialized in memory at once.
+        - `key`: The PRNG key.
+        """
+        if batch_size is None:
+            batch_size = n_draws
+
+        # Split keys
+        keys = jr.split(key, n_draws // batch_size)
+
+        @partial(jax.vmap, in_axes = 0)
+        def reconstruct_and_subset(draw: Array):
+            model = self.vari.reconstruct_model(draw).constrain()[0]
+
+            return getattr(model, node).obj
+
+        def batched_sample(carry: None, key: PRNGKeyArray):
+            # Sample draws
+            draws = self.vari.sample(batch_size, key)
+
+            return None, reconstruct_and_subset(draws)
+
+        posterior_draws = scan(
+            batched_sample,
+            init=None,
+            xs=keys,
+            length=n_draws // batch_size
+        )[1].squeeze()
+
+        return posterior_draws
+
+
+    def predictive(
+        self,
+        func: Callable[[M, PRNGKeyArray], Node[Array] | Array],
+        n_draws: int,
+        batch_size: Optional[int] = None,
+        key: PRNGKeyArray = jr.key(0)
+    ) -> Array:
+        """
+
+        """
+        if batch_size is None:
+            batch_size = n_draws
+
+        # Split keys
+        keys = jr.split(key, n_draws // batch_size)
+
+        @partial(jax.vmap, in_axes = (0, 0))
+        def reconstruct_and_predict(draw: Array, key: PRNGKeyArray) -> Array:
+            model = self.vari.reconstruct_model(draw).constrain()[0]
+
+            # Compute predictive
+            obj = func(model, key)
+
+            # Coerce from Node if needed
+            if isinstance(obj, Node):
+                obj: Array = obj.obj  # type: ignore
+
+            return obj
+
+        def batched_sample(carry: None, key: PRNGKeyArray) -> Tuple[None, Array]:
+            # Sample draws
+            draws = self.vari.sample(batch_size, key)
+
+            # Generate additional keys for each draw
+            keys = jr.split(key, batch_size)
+
+            return None, reconstruct_and_predict(draws, keys)
+
+        posterior_draws: Array = scan(
+            batched_sample,
+            init=None,
+            xs=keys,
+            length=n_draws // batch_size
+        )[1].squeeze()
+
+        return posterior_draws
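`Posterior.sample` and `Posterior.predictive` share one memory-bounding pattern: `scan` over batches of PRNG keys with a `vmap`-ed per-draw transform, so only `batch_size` reconstructed models exist at a time. A self-contained sketch of that pattern in plain JAX, with a toy stand-in for the per-draw work:

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
from jax.lax import scan

n_draws, batch_size, dim = 8, 4, 3
keys = jr.split(jr.key(0), n_draws // batch_size)

@partial(jax.vmap, in_axes=0)
def transform(draw):
    # Stand-in for reconstruct_and_subset: map one flat draw to one quantity
    return jnp.sum(draw**2)

def batched_sample(carry, key):
    draws = jr.normal(key, (batch_size, dim))  # stand-in for self.vari.sample
    return None, transform(draws)

# scan keeps one batch in memory; stacked outputs have shape
# (n_draws // batch_size, batch_size), flattened to n_draws draws.
out = scan(batched_sample, init=None, xs=keys)[1].reshape(-1)
print(out.shape)  # (8,)
```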
bayinx/vi/__init__.py
ADDED
File without changes
bayinx/{mhx/vi → vi}/meanfield.py
RENAMED
@@ -1,25 +1,28 @@
-from typing import
+from typing import Generic, Self

 import equinox as eqx
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
 from jax.flatten_util import ravel_pytree
-from jaxtyping import Array,
+from jaxtyping import Array, PRNGKeyArray, Scalar

-from bayinx.core import
+from bayinx.core.variational import M, Variational
 from bayinx.dists import normal

-
+
 class MeanField(Variational, Generic[M]):
     """
     A fully factorized Gaussian approximation to a posterior distribution.

     # Attributes
-    - `
+    - `dim`: The dimension of the support.
+    - `mean`: The mean of the unconstrained approximation.
+    - `log_std` The log-transformed standard deviation of the unconstrained approximation.
     """

-
+    mean: Array
+    log_std: Array

     def __init__(self, model: M):
         """
@@ -27,28 +30,31 @@ class MeanField(Variational, Generic[M]):

         # Parameters
         - `model`: A probabilistic `Model` object.
+        - `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
         """
         # Partition model
-        params, self.
+        params, self._static = eqx.partition(model, model.filter_spec)

         # Flatten params component
         params, self._unflatten = ravel_pytree(params)

         # Initialize variational parameters
-        self.
-
-            "log_std": jnp.zeros(params.size, dtype=params.dtype),
-        }
+        self.mean = params
+        self.log_std = jnp.full(params.size, 0.0)

     @property
-    @eqx.filter_jit
     def filter_spec(self):
         # Generate empty specification
         filter_spec = jtu.tree_map(lambda _: False, self)

         # Specify variational parameters
         filter_spec = eqx.tree_at(
-            lambda mf: mf.
+            lambda mf: mf.mean,
+            filter_spec,
+            replace=True,
+        )
+        filter_spec = eqx.tree_at(
+            lambda mf: mf.log_std,
             filter_spec,
             replace=True,
         )
@@ -56,12 +62,11 @@ class MeanField(Variational, Generic[M]):
         return filter_spec

     @eqx.filter_jit
-    def sample(self, n: int, key:
+    def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
         # Sample variational draws
         draws: Array = (
-            jr.normal(key=key, shape=(n, self.
-
-            + self.var_params["mean"]
+            jr.normal(key=key, shape=(n, self.mean.size)) * jnp.exp(self.log_std)
+            + self.mean
         )

         return draws
@@ -70,23 +75,23 @@ class MeanField(Variational, Generic[M]):
     def eval(self, draws: Array) -> Array:
         return normal.logprob(
             x=draws,
-            mu=self.
-            sigma=jnp.exp(self.
+            mu=self.mean,
+            sigma=jnp.exp(self.log_std),
         ).sum(axis=1)

     @eqx.filter_jit
-    def elbo(self, n: int,
+    def elbo(self, n: int, batch_size: int, key: PRNGKeyArray) -> Scalar:
         dyn, static = eqx.partition(self, self.filter_spec)

         @eqx.filter_jit
-        def elbo(dyn: Self, n: int, key:
+        def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
             vari = eqx.combine(dyn, static)

             # Sample draws from variational distribution
             draws: Array = vari.sample(n, key)

             # Evaluate posterior density for each draw
-            posterior_evals: Array = vari.eval_model(draws
+            posterior_evals: Array = vari.eval_model(draws)

             # Evaluate variational density for each draw
             variational_evals: Array = vari.eval(draws)
@@ -94,23 +99,22 @@ class MeanField(Variational, Generic[M]):
             # Evaluate ELBO
             return jnp.mean(posterior_evals - variational_evals)

-        return elbo(dyn, n, key
+        return elbo(dyn, n, key)

     @eqx.filter_jit
-    def elbo_grad(self, n: int,
+    def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
         dyn, static = eqx.partition(self, self.filter_spec)

-        @eqx.filter_grad
         @eqx.filter_jit
-
-
+        @eqx.filter_grad
+        def elbo_grad(dyn: Self, n: int, key: PRNGKeyArray):
             vari = eqx.combine(dyn, static)

             # Sample draws from variational distribution
             draws: Array = vari.sample(n, key)

             # Evaluate posterior density for each draw
-            posterior_evals: Array = vari.eval_model(draws
+            posterior_evals: Array = vari.eval_model(draws)

             # Evaluate variational density for each draw
             variational_evals: Array = vari.eval(draws)
@@ -118,4 +122,4 @@ class MeanField(Variational, Generic[M]):
             # Evaluate ELBO
             return jnp.mean(posterior_evals - variational_evals)

-        return elbo_grad(dyn, n, key
+        return elbo_grad(dyn, n, key)
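The renamed `MeanField` now stores `mean` and `log_std` directly instead of in a `var_params` dict; the estimator it computes is the standard reparameterized Monte Carlo ELBO. A self-contained sketch of that estimator against a toy standard-normal target (the target stands in for `vari.eval_model`):

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

def sample(mean, log_std, n, key):
    # Reparameterized draws, mirroring MeanField.sample above
    eps = jr.normal(key, (n, mean.size))
    return eps * jnp.exp(log_std) + mean

def elbo(mean, log_std, n, key):
    draws = sample(mean, log_std, n, key)
    posterior_evals = norm.logpdf(draws).sum(axis=1)  # toy target density
    variational_evals = norm.logpdf(draws, mean, jnp.exp(log_std)).sum(axis=1)
    return jnp.mean(posterior_evals - variational_evals)

mean, log_std = jnp.zeros(2), jnp.zeros(2)
print(elbo(mean, log_std, 1_000, jr.key(0)))  # ~0: q already matches the target
```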