bayinx 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/constraints.py +7 -10
- bayinx/core/flow.py +4 -4
- bayinx/core/model.py +3 -3
- bayinx/core/variational.py +11 -11
- bayinx/dists/normal.py +6 -4
- bayinx/dists/uniform.py +4 -4
- bayinx/mhx/vi/flows/fullaffine.py +11 -17
- bayinx/mhx/vi/flows/planar.py +4 -4
- bayinx/mhx/vi/flows/radial.py +3 -3
- bayinx/mhx/vi/meanfield.py +6 -8
- bayinx/mhx/vi/normalizing_flow.py +22 -26
- {bayinx-0.2.27.dist-info → bayinx-0.2.28.dist-info}/METADATA +1 -1
- bayinx-0.2.28.dist-info/RECORD +27 -0
- bayinx/core/utils.py +0 -1
- bayinx-0.2.27.dist-info/RECORD +0 -28
- {bayinx-0.2.27.dist-info → bayinx-0.2.28.dist-info}/WHEEL +0 -0
bayinx/core/constraints.py
CHANGED
@@ -9,16 +9,12 @@ from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
|
|
9
9
|
class Constraint(eqx.Module):
|
10
10
|
"""
|
11
11
|
Abstract base class for defining parameter constraints.
|
12
|
-
|
13
|
-
Subclasses should implement the `constrain` method to apply the
|
14
|
-
transformation and compute the ladj adjustment.
|
15
12
|
"""
|
13
|
+
|
16
14
|
@abstractmethod
|
17
15
|
def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
|
18
16
|
"""
|
19
|
-
Applies the constraining transformation to an unconstrained input
|
20
|
-
and computes the log absolute determinant of the Jacobian (ladj)
|
21
|
-
of this transformation.
|
17
|
+
Applies the constraining transformation to an unconstrained input and computes the log-absolute-jacobian of the transformation.
|
22
18
|
|
23
19
|
# Parameters
|
24
20
|
- `x`: The unconstrained JAX Array-like input.
|
@@ -26,7 +22,7 @@ class Constraint(eqx.Module):
|
|
26
22
|
# Returns
|
27
23
|
A tuple containing:
|
28
24
|
- The constrained JAX Array.
|
29
|
-
- A scalar JAX Array representing the
|
25
|
+
- A scalar JAX Array representing the laj of the transformation.
|
30
26
|
"""
|
31
27
|
pass
|
32
28
|
|
@@ -35,6 +31,7 @@ class LowerBound(Constraint):
|
|
35
31
|
"""
|
36
32
|
Enforces a lower bound on the parameter.
|
37
33
|
"""
|
34
|
+
|
38
35
|
lb: ScalarLike
|
39
36
|
|
40
37
|
def __init__(self, lb: ScalarLike):
|
@@ -42,7 +39,7 @@ class LowerBound(Constraint):
|
|
42
39
|
|
43
40
|
def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
|
44
41
|
"""
|
45
|
-
Applies the lower bound constraint and computes the
|
42
|
+
Applies the lower bound constraint and computes the laj.
|
46
43
|
|
47
44
|
# Parameters
|
48
45
|
- `x`: The unconstrained JAX Array-like input.
|
@@ -50,10 +47,10 @@ class LowerBound(Constraint):
|
|
50
47
|
# Parameters
|
51
48
|
A tuple containing:
|
52
49
|
- The constrained JAX Array (x > self.lb).
|
53
|
-
- A scalar JAX Array representing the
|
50
|
+
- A scalar JAX Array representing the laj of the transformation.
|
54
51
|
"""
|
55
52
|
# Compute transformation adjustment
|
56
|
-
ladj = jnp.sum(x)
|
53
|
+
ladj: Scalar = jnp.sum(x)
|
57
54
|
|
58
55
|
# Compute transformation
|
59
56
|
x = jnp.exp(x) + self.lb
|
bayinx/core/flow.py
CHANGED
@@ -8,11 +8,11 @@ from jaxtyping import Array, Float
|
|
8
8
|
|
9
9
|
class Flow(eqx.Module):
|
10
10
|
"""
|
11
|
-
|
11
|
+
An abstract base class for a flow(of a normalizing flow).
|
12
12
|
|
13
13
|
# Attributes
|
14
14
|
- `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
|
15
|
-
- `constraints`: A dictionary of functions that constrain their corresponding parameter.
|
15
|
+
- `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
|
16
16
|
"""
|
17
17
|
|
18
18
|
params: Dict[str, Float[Array, "..."]]
|
@@ -28,10 +28,10 @@ class Flow(eqx.Module):
|
|
28
28
|
@abstractmethod
|
29
29
|
def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
|
30
30
|
"""
|
31
|
-
Computes the log-absolute-
|
31
|
+
Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.
|
32
32
|
|
33
33
|
# Returns
|
34
|
-
A tuple of JAX Arrays containing the log-absolute-
|
34
|
+
A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
|
35
35
|
"""
|
36
36
|
pass
|
37
37
|
|
bayinx/core/model.py
CHANGED
@@ -11,11 +11,11 @@ from bayinx.core.constraints import Constraint
|
|
11
11
|
|
12
12
|
class Model(eqx.Module):
|
13
13
|
"""
|
14
|
-
|
14
|
+
An abstract base class used to define probabilistic models.
|
15
15
|
|
16
16
|
# Attributes
|
17
17
|
- `params`: A dictionary of JAX Arrays representing parameters of the model.
|
18
|
-
- `constraints`: A dictionary of
|
18
|
+
- `constraints`: A dictionary of constraints.
|
19
19
|
"""
|
20
20
|
|
21
21
|
params: Dict[str, Array]
|
@@ -63,7 +63,7 @@ class Model(eqx.Module):
|
|
63
63
|
|
64
64
|
return t_params, target
|
65
65
|
|
66
|
-
|
66
|
+
# Add default transform method
|
67
67
|
def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
|
68
68
|
"""
|
69
69
|
Apply a custom transformation to `params` if needed.
|
bayinx/core/variational.py
CHANGED
@@ -8,7 +8,7 @@ import jax.lax as lax
|
|
8
8
|
import jax.numpy as jnp
|
9
9
|
import jax.random as jr
|
10
10
|
import optax as opx
|
11
|
-
from jaxtyping import Array,
|
11
|
+
from jaxtyping import Array, Key, PyTree, Scalar
|
12
12
|
from optax import GradientTransformation, OptState, Schedule
|
13
13
|
|
14
14
|
from bayinx.core import Model
|
@@ -16,16 +16,23 @@ from bayinx.core import Model
|
|
16
16
|
|
17
17
|
class Variational(eqx.Module):
|
18
18
|
"""
|
19
|
-
|
19
|
+
An abstract base class used to define variational methods.
|
20
20
|
|
21
21
|
# Attributes
|
22
22
|
- `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
|
23
23
|
- `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
|
24
24
|
"""
|
25
25
|
|
26
|
-
_unflatten: Callable[[
|
26
|
+
_unflatten: Callable[[Array], Model]
|
27
27
|
_constraints: Model
|
28
28
|
|
29
|
+
@abstractmethod
|
30
|
+
def filter_spec(self):
|
31
|
+
"""
|
32
|
+
Filter specification for dynamic and static components of the `Variational`.
|
33
|
+
"""
|
34
|
+
pass
|
35
|
+
|
29
36
|
@abstractmethod
|
30
37
|
def sample(self, n: int, key: Key) -> Array:
|
31
38
|
"""
|
@@ -54,13 +61,6 @@ class Variational(eqx.Module):
|
|
54
61
|
"""
|
55
62
|
pass
|
56
63
|
|
57
|
-
@abstractmethod
|
58
|
-
def filter_spec(self):
|
59
|
-
"""
|
60
|
-
Filter specification for dynamic and static components of the `Variational`.
|
61
|
-
"""
|
62
|
-
pass
|
63
|
-
|
64
64
|
@eqx.filter_jit
|
65
65
|
@partial(jax.vmap, in_axes=(None, 0, None))
|
66
66
|
def eval_model(self, draws: Array, data: Any = None) -> Array:
|
@@ -135,7 +135,7 @@ class Variational(eqx.Module):
|
|
135
135
|
# Update PRNG key
|
136
136
|
key, _ = jr.split(key)
|
137
137
|
|
138
|
-
#
|
138
|
+
# Reconstruct variational
|
139
139
|
vari = eqx.combine(dyn, static)
|
140
140
|
|
141
141
|
# Compute gradient of the ELBO
|
bayinx/dists/normal.py
CHANGED
@@ -19,7 +19,7 @@ def prob(
|
|
19
19
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
20
20
|
"""
|
21
21
|
|
22
|
-
return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / (
|
22
|
+
return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / ( # pyright: ignore
|
23
23
|
sigma * _lax.sqrt(2.0 * __PI)
|
24
24
|
)
|
25
25
|
|
@@ -39,7 +39,9 @@ def logprob(
|
|
39
39
|
The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
40
40
|
"""
|
41
41
|
|
42
|
-
return -_lax.log(sigma * _lax.sqrt(2.0 * __PI)) - 0.5 * _lax.square(
|
42
|
+
return -_lax.log(sigma * _lax.sqrt(2.0 * __PI)) - 0.5 * _lax.square(
|
43
|
+
(x - mu) / sigma # pyright: ignore
|
44
|
+
)
|
43
45
|
|
44
46
|
|
45
47
|
def uprob(
|
@@ -57,7 +59,7 @@ def uprob(
|
|
57
59
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
58
60
|
"""
|
59
61
|
|
60
|
-
return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma
|
62
|
+
return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma # pyright: ignore
|
61
63
|
|
62
64
|
|
63
65
|
def ulogprob(
|
@@ -75,4 +77,4 @@ def ulogprob(
|
|
75
77
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
76
78
|
"""
|
77
79
|
|
78
|
-
return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma)
|
80
|
+
return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
|
bayinx/dists/uniform.py
CHANGED
@@ -18,7 +18,7 @@ def prob(
|
|
18
18
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
19
19
|
"""
|
20
20
|
|
21
|
-
return 1.0 / (ub - lb)
|
21
|
+
return 1.0 / (ub - lb) # pyright: ignore
|
22
22
|
|
23
23
|
|
24
24
|
def logprob(
|
@@ -36,7 +36,7 @@ def logprob(
|
|
36
36
|
The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
37
37
|
"""
|
38
38
|
|
39
|
-
return _lax.log(1.0) - _lax.log(ub - lb)
|
39
|
+
return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
|
40
40
|
|
41
41
|
|
42
42
|
def uprob(
|
@@ -54,7 +54,7 @@ def uprob(
|
|
54
54
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
55
55
|
"""
|
56
56
|
|
57
|
-
return jnp.ones(jnp.broadcast_arrays(x,lb,ub))
|
57
|
+
return jnp.ones(jnp.broadcast_arrays(x, lb, ub))
|
58
58
|
|
59
59
|
|
60
60
|
def ulogprob(
|
@@ -72,4 +72,4 @@ def ulogprob(
|
|
72
72
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
73
73
|
"""
|
74
74
|
|
75
|
-
return jnp.zeros(jnp.broadcast_arrays(x,lb,ub))
|
75
|
+
return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
|
@@ -1,29 +1,26 @@
|
|
1
1
|
from functools import partial
|
2
|
-
from typing import
|
2
|
+
from typing import Tuple
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
-
from jaxtyping import Array,
|
7
|
+
from jaxtyping import Array, Scalar
|
8
8
|
|
9
9
|
from bayinx.core import Flow
|
10
10
|
|
11
11
|
|
12
12
|
class FullAffine(Flow):
|
13
13
|
"""
|
14
|
-
|
14
|
+
A full affine flow.
|
15
15
|
|
16
16
|
# Attributes
|
17
17
|
- `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
|
18
18
|
- `constraints`: A dictionary of constraining transformations.
|
19
19
|
"""
|
20
20
|
|
21
|
-
params: Dict[str, Float[Array, "..."]]
|
22
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
23
|
-
|
24
21
|
def __init__(self, dim: int):
|
25
22
|
"""
|
26
|
-
Initializes
|
23
|
+
Initializes a full affine flow.
|
27
24
|
|
28
25
|
# Parameters
|
29
26
|
- `dim`: The dimension of the parameter space.
|
@@ -35,21 +32,18 @@ class FullAffine(Flow):
|
|
35
32
|
|
36
33
|
self.constraints = {"scale": lambda m: jnp.tril(m)}
|
37
34
|
|
35
|
+
@eqx.filter_jit
|
38
36
|
def transform_pars(self):
|
39
|
-
# Get constrained parameters
|
40
37
|
params = self.constrain_pars()
|
41
38
|
|
42
39
|
# Extract diagonal and apply exponential
|
43
|
-
diag: Array = jnp.exp(jnp.diag(params[
|
40
|
+
diag: Array = jnp.exp(jnp.diag(params["scale"]))
|
44
41
|
|
45
42
|
# Fill diagonal
|
46
|
-
params[
|
47
|
-
|
43
|
+
params["scale"] = jnp.fill_diagonal(params["scale"], diag, inplace=False)
|
48
44
|
|
49
45
|
return params
|
50
46
|
|
51
|
-
|
52
|
-
|
53
47
|
@eqx.filter_jit
|
54
48
|
def forward(self, draws: Array) -> Array:
|
55
49
|
params = self.transform_pars()
|
@@ -65,7 +59,7 @@ class FullAffine(Flow):
|
|
65
59
|
|
66
60
|
@eqx.filter_jit
|
67
61
|
@partial(jax.vmap, in_axes=(None, 0))
|
68
|
-
def adjust_density(self, draws: Array) -> Tuple[
|
62
|
+
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
69
63
|
params = self.transform_pars()
|
70
64
|
|
71
65
|
# Extract parameters
|
@@ -75,7 +69,7 @@ class FullAffine(Flow):
|
|
75
69
|
# Compute forward transformation
|
76
70
|
draws = draws @ scale + shift
|
77
71
|
|
78
|
-
# Compute
|
79
|
-
|
72
|
+
# Compute laj
|
73
|
+
laj: Scalar = jnp.log(jnp.diag(scale)).sum()
|
80
74
|
|
81
|
-
return
|
75
|
+
return draws, laj
|
bayinx/mhx/vi/flows/planar.py
CHANGED
@@ -53,7 +53,7 @@ class Planar(Flow):
|
|
53
53
|
|
54
54
|
@eqx.filter_jit
|
55
55
|
@partial(jax.vmap, in_axes=(None, 0))
|
56
|
-
def adjust_density(self, draws: Array) -> Tuple[
|
56
|
+
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
57
57
|
params = self.transform_pars()
|
58
58
|
|
59
59
|
# Extract parameters
|
@@ -67,8 +67,8 @@ class Planar(Flow):
|
|
67
67
|
# Compute forward transformation
|
68
68
|
draws = draws + u * jnp.tanh(x)
|
69
69
|
|
70
|
-
# Compute
|
70
|
+
# Compute laj
|
71
71
|
h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
|
72
|
-
|
72
|
+
laj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
|
73
73
|
|
74
|
-
return
|
74
|
+
return draws, laj
|
bayinx/mhx/vi/flows/radial.py
CHANGED
@@ -66,7 +66,7 @@ class Radial(Flow):
|
|
66
66
|
|
67
67
|
@partial(jax.vmap, in_axes=(None, 0))
|
68
68
|
@eqx.filter_jit
|
69
|
-
def adjust_density(self, draws: Array) -> Tuple[
|
69
|
+
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
70
70
|
params = self.transform_pars()
|
71
71
|
|
72
72
|
# Extract parameters
|
@@ -84,11 +84,11 @@ class Radial(Flow):
|
|
84
84
|
draws = draws + (x) * (draws - center)
|
85
85
|
|
86
86
|
# Compute density adjustment
|
87
|
-
|
87
|
+
laj = jnp.log(
|
88
88
|
jnp.abs(
|
89
89
|
(1.0 + alpha * beta / (alpha + r) ** 2.0)
|
90
90
|
* (1.0 + x) ** (center.size - 1.0)
|
91
91
|
)
|
92
92
|
)
|
93
93
|
|
94
|
-
return
|
94
|
+
return draws, laj
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -20,8 +20,8 @@ class MeanField(Variational):
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
var_params: Dict[str, Float[Array, "..."]]
|
23
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
-
_constraints: Model
|
23
|
+
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
+
_constraints: Model
|
25
25
|
|
26
26
|
def __init__(self, model: Model):
|
27
27
|
"""
|
@@ -63,25 +63,24 @@ class MeanField(Variational):
|
|
63
63
|
|
64
64
|
@eqx.filter_jit
|
65
65
|
def filter_spec(self):
|
66
|
+
# Generate empty specification
|
66
67
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
68
|
+
|
69
|
+
# Specify variational parameters
|
67
70
|
filter_spec = eqx.tree_at(
|
68
71
|
lambda mf: mf.var_params,
|
69
72
|
filter_spec,
|
70
73
|
replace=True,
|
71
74
|
)
|
75
|
+
|
72
76
|
return filter_spec
|
73
77
|
|
74
78
|
@eqx.filter_jit
|
75
79
|
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
76
|
-
"""
|
77
|
-
Estimate the ELBO and its gradient(w.r.t the variational parameters).
|
78
|
-
"""
|
79
|
-
# Partition variational
|
80
80
|
dyn, static = eqx.partition(self, self.filter_spec())
|
81
81
|
|
82
82
|
@eqx.filter_jit
|
83
83
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
|
84
|
-
# Combine
|
85
84
|
vari = eqx.combine(dyn, static)
|
86
85
|
|
87
86
|
# Sample draws from variational distribution
|
@@ -100,7 +99,6 @@ class MeanField(Variational):
|
|
100
99
|
|
101
100
|
@eqx.filter_jit
|
102
101
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
103
|
-
# Partition
|
104
102
|
dyn, static = eqx.partition(self, self.filter_spec())
|
105
103
|
|
106
104
|
@eqx.filter_grad
|
@@ -1,11 +1,11 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Self, Tuple
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.flatten_util as jfu
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import jax.random as jr
|
7
7
|
import jax.tree_util as jtu
|
8
|
-
from jaxtyping import Array,
|
8
|
+
from jaxtyping import Array, Key, Scalar
|
9
9
|
|
10
10
|
from bayinx.core import Flow, Model, Variational
|
11
11
|
|
@@ -17,14 +17,11 @@ class NormalizingFlow(Variational):
|
|
17
17
|
|
18
18
|
# Attributes
|
19
19
|
- `base`: A base variational distribution.
|
20
|
-
- `flows`: An ordered collection of continuously parameterized
|
21
|
-
diffeomorphisms.
|
20
|
+
- `flows`: An ordered collection of continuously parameterized diffeomorphisms.
|
22
21
|
"""
|
23
22
|
|
24
23
|
flows: list[Flow]
|
25
24
|
base: Variational
|
26
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
27
|
-
_constraints: Model
|
28
25
|
|
29
26
|
def __init__(self, base: Variational, flows: list[Flow], model: Model):
|
30
27
|
"""
|
@@ -44,6 +41,19 @@ class NormalizingFlow(Variational):
|
|
44
41
|
self.base = base
|
45
42
|
self.flows = flows
|
46
43
|
|
44
|
+
def filter_spec(self):
|
45
|
+
# Generate empty specification
|
46
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
47
|
+
|
48
|
+
# Specify variational parameters based on each flow's filter spec.
|
49
|
+
filter_spec = eqx.tree_at(
|
50
|
+
lambda vari: vari.flows,
|
51
|
+
filter_spec,
|
52
|
+
replace=[flow.filter_spec() for flow in self.flows],
|
53
|
+
)
|
54
|
+
|
55
|
+
return filter_spec
|
56
|
+
|
47
57
|
@eqx.filter_jit
|
48
58
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)):
|
49
59
|
"""
|
@@ -65,19 +75,18 @@ class NormalizingFlow(Variational):
|
|
65
75
|
|
66
76
|
for map in self.flows:
|
67
77
|
# Compute adjustment
|
68
|
-
|
78
|
+
laj, draws = map.adjust_density(draws)
|
69
79
|
|
70
80
|
# Adjust variational density
|
71
|
-
variational_evals = variational_evals -
|
81
|
+
variational_evals = variational_evals - laj
|
72
82
|
|
73
83
|
return variational_evals
|
74
84
|
|
75
85
|
@eqx.filter_jit
|
76
86
|
def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
|
77
87
|
"""
|
78
|
-
Evaluate the posterior and variational densities at the
|
79
|
-
`draws` to avoid extra compute
|
80
|
-
the posterior evaluation.
|
88
|
+
Evaluate the posterior and variational densities together at the
|
89
|
+
transformed `draws` to avoid extra compute.
|
81
90
|
|
82
91
|
# Parameters
|
83
92
|
- `draws`: Draws from the base variational distribution.
|
@@ -91,29 +100,16 @@ class NormalizingFlow(Variational):
|
|
91
100
|
|
92
101
|
for map in self.flows:
|
93
102
|
# Compute adjustment
|
94
|
-
|
103
|
+
draws, laj = map.adjust_density(draws)
|
95
104
|
|
96
105
|
# Adjust variational density
|
97
|
-
variational_evals = variational_evals -
|
106
|
+
variational_evals = variational_evals - laj
|
98
107
|
|
99
108
|
# Evaluate posterior at final variational draws
|
100
109
|
posterior_evals = self.eval_model(draws, data)
|
101
110
|
|
102
111
|
return posterior_evals, variational_evals
|
103
112
|
|
104
|
-
def filter_spec(self):
|
105
|
-
# Generate empty specification
|
106
|
-
filter_spec = jtu.tree_map(lambda _: False, self)
|
107
|
-
|
108
|
-
# Specify variational parameters based on each flow's filter spec.
|
109
|
-
filter_spec = eqx.tree_at(
|
110
|
-
lambda vari: vari.flows,
|
111
|
-
filter_spec,
|
112
|
-
replace=[flow.filter_spec() for flow in self.flows],
|
113
|
-
)
|
114
|
-
|
115
|
-
return filter_spec
|
116
|
-
|
117
113
|
@eqx.filter_jit
|
118
114
|
def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
|
119
115
|
dyn, static = eqx.partition(self, self.filter_spec())
|
@@ -0,0 +1,27 @@
|
|
1
|
+
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
+
bayinx/core/constraints.py,sha256=lbVs2-xjGRue17YRPGHz3s_mJ0ZiunpYowbD0QvcD-I,1525
|
5
|
+
bayinx/core/flow.py,sha256=A5Vw5t76LPasnMgghjw6ulBkIm5L2jBprusVt-tuwko,2296
|
6
|
+
bayinx/core/model.py,sha256=Z_HaFr0_-keMjG5tg3xxP3hGML7aDFIcCI8Y5dGrtM4,2145
|
7
|
+
bayinx/core/variational.py,sha256=W0747jfVJFAtMZqL3mpbtl2wfnARHln-dVBag4xZ23Y,4813
|
8
|
+
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
|
+
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
+
bayinx/dists/normal.py,sha256=3CXSgHWnuglmP8cKVUh2Yt4Rb9_LR_mwPRXDm_LuSRo,2679
|
14
|
+
bayinx/dists/uniform.py,sha256=mogFe8VuDelM9KXE6RxGek0-tuZYFrwmo_oMOPHXleA,2359
|
15
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
16
|
+
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
17
|
+
bayinx/mhx/vi/meanfield.py,sha256=8hM1KZ52TpRPLwiQcowsJLlQ-5nJzUEcKrtDiGrFoSs,3732
|
18
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=FvxDtqGRtaEeeF-bXCYnIEAvOOXVHKUK0oCTF9ma02Y,4622
|
19
|
+
bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
|
20
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
21
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=Kvaa8epqaqz9tdMCnf9T_-2P3Bh_TkhA6NrilKHY93A,1886
|
22
|
+
bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
|
23
|
+
bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
|
24
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
25
|
+
bayinx-0.2.28.dist-info/METADATA,sha256=xe3Wlo3UlD3VuTc42ChwnPTL6lp3BZmxnuf0gnZxWv0,3058
|
26
|
+
bayinx-0.2.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
27
|
+
bayinx-0.2.28.dist-info/RECORD,,
|
bayinx/core/utils.py
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
|
bayinx-0.2.27.dist-info/RECORD
DELETED
@@ -1,28 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
-
bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
|
5
|
-
bayinx/core/flow.py,sha256=9swS5wh7AsIZWgG_IhQS-upcPlr7G-juaP_5rsbX6G0,2363
|
6
|
-
bayinx/core/model.py,sha256=U1xBnAXnIvFJjWF-XIT8BYjP9PtoRZY_PwyhRwJf-HA,2144
|
7
|
-
bayinx/core/utils.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
8
|
-
bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
|
9
|
-
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
11
|
-
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
bayinx/dists/normal.py,sha256=rtSDi0NAObH1LGRWiPZk_6cbSVv2dOPHkgxtWn6gFgM,2662
|
15
|
-
bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
|
16
|
-
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
17
|
-
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
18
|
-
bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
|
19
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
|
20
|
-
bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
|
21
|
-
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
22
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=nXcPTZ_GSxIg7tmVxag694Fl1F95SKFSyDyt-9EDC9I,2049
|
23
|
-
bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
|
24
|
-
bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
|
25
|
-
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
26
|
-
bayinx-0.2.27.dist-info/METADATA,sha256=5RPhGKmb6wWJquxrUlyt6QXWTSPEQycu5nFVZmQN9bU,3058
|
27
|
-
bayinx-0.2.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
28
|
-
bayinx-0.2.27.dist-info/RECORD,,
|
File without changes
|