bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/flows/fullaffine.py
ADDED
@@ -0,0 +1,123 @@
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+class FullAffineLayer(FlowLayer):
+    """
+    A full affine flow.
+
+    # Attributes
+    - `params`: The parameters of the full affine flow.
+    - `constraints`: The constraining transformations for the parameters of the full affine flow.
+    - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
+    """
+
+    params: Dict[str, Array]
+    constraints: Dict[str, Callable[[Array], Array]]
+    static: bool
+
+    def __init__(self, dim: int, key: PRNGKeyArray):
+        """
+        Initializes a full affine flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.static = False
+
+        # Split key
+        k1, k2 = jr.split(key)
+
+        # Initialize parameters
+        self.params = {
+            "shift": jr.normal(k1, (dim,)) / dim**0.5,
+            "scale": jr.normal(k2, (dim, dim)) / dim**0.5,
+        }
+
+        # Define constraints
+        if dim == 1:
+            self.constraints = {"scale": jnp.exp}
+        else:
+            def constrain_scale(scale: Array):
+                # Extract diagonal and apply exponential
+                diag: Array = jnp.exp(jnp.diag(scale))
+
+                # Return lower-triangular matrix with positive diagonal
+                return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)
+
+            self.constraints = {"scale": constrain_scale}
+
+    @eqx.filter_jit
+    def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+        params = self.transform_params()
+
+        # Extract parameters
+        shift: Float[Array, " n_dim"] = params["shift"]
+        scale: Float[Array, "n_dim n_dim"] = params["scale"]
+
+        # Compute forward transformation
+        draws = (scale @ draws.T).T + shift
+
+        return draws
+
+    def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:
+        params = self.transform_params()
+
+        # Extract parameters
+        scale: Float[Array, "n_dim n_dim"] = params["scale"]
+
+        # Compute log-Jacobian adjustment
+        lja: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+        assert lja.shape == ()
+
+        return lja
+
+    @eqx.filter_jit
+    def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, " n_draws"]:
+        f = jax.vmap(self.__adjust, 0)
+        return f(draws)
+
+    def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+        params = self.transform_params()
+
+        assert len(draw.shape) == 1
+
+        # Extract parameters
+        shift: Float[Array, " n_dim"] = params["shift"]
+        scale: Float[Array, "n_dim n_dim"] = params["scale"]
+
+        # Compute forward transformation
+        draw = (scale @ draw.T).T + shift
+
+        assert len(draw.shape) == 1
+
+        # Compute log-Jacobian adjustment
+        lja: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+        assert lja.shape == ()
+
+        return draw, lja
+
+    @eqx.filter_jit
+    def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Float[Array, " n_draws"]]:
+        f = jax.vmap(self.__forward_and_adjust, 0)
+        return f(draws)
+
+
+class FullAffine(FlowSpec):
+    """
+    A specification for the full affine flow.
+    """
+    key: PRNGKeyArray
+
+    def __init__(self, key: PRNGKeyArray = jr.key(0)):
+        self.key = key
+
+    def construct(self, dim: int) -> FullAffineLayer:
+        return FullAffineLayer(dim, self.key)
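The `FlowSpec`/`FlowLayer` split above separates configuration from construction: the spec holds hyperparameters (here just a PRNG key), and `construct(dim)` builds the layer once the dimension of the parameter space is known. A minimal usage sketch, assuming bayinx 0.5.3 is installed and that the `FlowLayer` base class supplies the `transform_params`/`constrain_params` machinery referenced above:

import jax.random as jr

from bayinx.flows.fullaffine import FullAffine

# Build a layer for a 3-dimensional parameter space.
spec = FullAffine(key=jr.key(42))
layer = spec.construct(dim=3)

# Push a batch of standard-normal draws through the flow.
draws = jr.normal(jr.key(0), (100, 3))
transformed, ljas = layer.forward_and_adjust(draws)

print(transformed.shape)  # (100, 3)
print(ljas.shape)         # (100,): one log-Jacobian adjustment per draw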
bayinx/flows/lowrankaffine.py
ADDED
@@ -0,0 +1,165 @@
+from typing import Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jax.lax import scan
+from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+def _mat_op(
+    carry: Float[Array, " rank"],
+    vars: Tuple[
+        Float[Array, " rank"],
+        Float[Array, " rank"],
+        Scalar,
+        Scalar,
+        Scalar
+    ]
+) -> Tuple[Float[Array, " rank"], Scalar]:
+    """
+    The implicit matrix operations compute:
+
+    x = Lz + b = (D + W)z + b = Dz + (UV^T * M)z + b
+
+    where W = UV^T * M is the low-rank representation of the strictly
+    lower-triangular matrix W, parameterized by U and V, and an implicit
+    mask M that zeros out the upper-triangular elements.
+
+    In the implementation, only U, V, diag(D), and b are needed to complete
+    the matrix operations.
+    """
+    # Unpack state
+    U_r, V_r, z_i, d_i, b_i = vars
+    h_i = carry
+
+    # Accumulate partial product (running sum over j < i of z_j * V_j)
+    h_next = h_i + z_i * V_r
+
+    # Compute i-th element of transformed draw
+    x_i = b_i + z_i * d_i + U_r.dot(h_i)
+
+    return h_next, x_i
+
+
+class LowRankAffineLayer(FlowLayer):
+    """
+    A low-rank affine flow.
+
+    # Attributes
+    - `params`: A dictionary containing the shift and low-rank representation of the scale parameters.
+    - `constraints`: A dictionary of constraining transformations.
+    - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
+    - `rank`: Rank of the scale transformation.
+    """
+
+    rank: int
+
+    def __init__(self, dim: int, rank: int, key: PRNGKeyArray = jr.key(0)):
+        """
+        Initializes a low-rank affine flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        - `rank`: The rank of the (implicit) scale matrix.
+        """
+        self.static = False
+        self.rank = rank
+
+        # Split key
+        k1, k2, k3, k4 = jr.split(key, 4)
+
+        # Initialize parameters
+        self.params = {
+            "shift": jr.normal(k1, (dim,)) / dim**0.5,
+            "diag_scale": jr.normal(k2, (dim,)) / dim**0.5,
+            "offdiag_scale": (
+                jr.normal(k3, (dim, rank)) / dim**0.5,
+                jr.normal(k4, (dim, rank)) / dim**0.5
+            )
+        }
+
+        # Define constraints
+        self.constraints = {"diag_scale": jnp.exp}
+
+    def __forward(self, draw: Float[Array, " n_dim"]) -> Float[Array, " n_dim"]:
+        params = self.transform_params()
+
+        # Extract parameters
+        shift: Array = params["shift"]
+        diag: Array = params["diag_scale"]
+        U, V = params["offdiag_scale"]
+
+        # Compute forward transformation
+        _, draw = scan(
+            f=_mat_op,
+            init=jnp.zeros((self.rank,)),
+            xs=(U, V, draw, diag, shift)
+        )
+
+        return draw
+
+    @eqx.filter_jit
+    def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+        f = jax.vmap(self.__forward, 0)
+        return f(draws)
+
+    def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:
+        params = self.transform_params()
+
+        diag: Array = params["diag_scale"]
+
+        # Compute log-Jacobian adjustment
+        lja: Scalar = jnp.log(diag).sum()
+
+        assert lja.shape == ()
+
+        return lja
+
+    @eqx.filter_jit
+    def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, " n_draws"]:
+        f = jax.vmap(self.__adjust, 0)
+        return f(draws)
+
+    def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+        params = self.transform_params()
+
+        assert len(draw.shape) == 1
+
+        # Extract parameters
+        shift: Array = params["shift"]
+        diag: Array = params["diag_scale"]
+        U, V = params["offdiag_scale"]
+
+        # Compute log-Jacobian adjustment
+        lja: Scalar = jnp.log(diag).sum()
+
+        # Compute forward transformation
+        _, draw = scan(
+            f=_mat_op,
+            init=jnp.zeros((self.rank,)),
+            xs=(U, V, draw, diag, shift)
+        )
+
+        assert len(draw.shape) == 1
+        assert lja.shape == ()
+
+        return draw, lja
+
+    @eqx.filter_jit
+    def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Float[Array, " n_draws"]]:
+        f = jax.vmap(self.__forward_and_adjust, 0)
+        return f(draws)
+
+
+class LowRankAffine(FlowSpec):
+    """
+    A specification for the low-rank affine flow.
+    """
+    rank: int
+
+    def __init__(self, rank: int):
+        self.rank = rank
+
+    def construct(self, dim: int) -> LowRankAffineLayer:
+        if self.rank > (dim - 1) / 2:
+            raise ValueError(
+                f"Rank {self.rank} exceeds (dim - 1)/2 for dim {dim}; "
+                "consider using a full affine flow instead."
+            )
+
+        return LowRankAffineLayer(dim, self.rank)
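The scan in `__forward` never materializes the dim x dim scale matrix: each step accumulates the running sum over j < i of z_j * V_j and emits x_i, so the cost is O(dim * rank) rather than O(dim^2). A sketch checking the implicit product against its dense equivalent, assuming the module-level helper `_mat_op` defined above is importable:

import jax.numpy as jnp
import jax.random as jr
from jax.lax import scan

from bayinx.flows.lowrankaffine import _mat_op

dim, rank = 5, 2
k1, k2, k3, k4, k5 = jr.split(jr.key(0), 5)
U = jr.normal(k1, (dim, rank))
V = jr.normal(k2, (dim, rank))
z = jr.normal(k3, (dim,))
d = jnp.exp(jr.normal(k4, (dim,)))  # positive, like the constrained diagonal
b = jr.normal(k5, (dim,))

# Implicit scan-based product, as used by LowRankAffineLayer
_, x_scan = scan(f=_mat_op, init=jnp.zeros((rank,)), xs=(U, V, z, d, b))

# Explicit dense equivalent: L = D + strictly-lower part of U V^T
L = jnp.diag(d) + jnp.tril(U @ V.T, k=-1)
x_dense = L @ z + b

print(jnp.allclose(x_scan, x_dense, atol=1e-5))  # True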
bayinx/flows/planar.py
ADDED
@@ -0,0 +1,155 @@
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+def _h(x: Array) -> Array:
+    return jnp.tanh(x)
+
+
+def _dh(x: Array) -> Array:
+    return 1.0 - jnp.tanh(x)**2
+
+
+class PlanarLayer(FlowLayer):
+    """
+    A Planar flow.
+
+    # Attributes
+    - `params`: The parameters of the Planar flow (w, u_hat, b).
+    - `constraints`: The constraining transformations for the parameters.
+    - `static`: Whether the flow layer is frozen.
+    """
+
+    params: Dict[str, Array]
+    constraints: Dict[str, Callable[[Array], Array]]
+    static: bool
+
+    def __init__(self, dim: int, key: PRNGKeyArray = jr.key(0)):
+        """
+        Initializes a Planar flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space.
+        """
+        self.static = False
+
+        # Split key
+        k1, k2, k3 = jr.split(key, 3)
+
+        # Initialize parameters
+        self.params = {
+            "w": jr.normal(k1, (dim,)) / dim**0.5,
+            "u_hat": jr.normal(k2, (dim,)) / dim**0.5,
+            "b": jr.normal(k3, ()) / dim**0.5,
+        }
+
+        self.constraints = {}
+
+    def __forward(self, draw: Float[Array, " n_dim"]) -> Float[Array, " n_dim"]:
+        params = self.transform_params()
+
+        # Extract parameters
+        w: Float[Array, " n_dim"] = params["w"]
+        u: Float[Array, " n_dim"] = params["u"]
+        b: Scalar = params["b"]
+
+        # Compute inner term
+        a = draw.dot(w) + b
+
+        # Compute nonlinear stretch
+        h = _h(a)
+
+        # Compute forward transformation
+        draw = draw + u * h
+
+        return draw
+
+    @eqx.filter_jit
+    def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+        f = jax.vmap(self.__forward, 0)
+        return f(draws)
+
+    def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:  # noqa
+        params = self.transform_params()
+
+        # Extract parameters
+        w: Float[Array, " n_dim"] = params["w"]
+        u: Float[Array, " n_dim"] = params["u"]
+        b: Scalar = params["b"]
+
+        # Compute inner term
+        a = draw.dot(w) + b
+
+        # Compute log-Jacobian adjustment
+        lja: Scalar = jnp.log(1.0 + u.dot(_dh(a) * w))
+
+        assert lja.shape == ()
+        return lja
+
+    @eqx.filter_jit
+    def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, " n_draws"]:
+        f = jax.vmap(self.__adjust, 0)
+        return f(draws)
+
+    def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+        params = self.transform_params()
+
+        # Extract parameters
+        w: Float[Array, " n_dim"] = params["w"]
+        u: Float[Array, " n_dim"] = params["u"]
+        b: Scalar = params["b"]
+
+        # Compute inner term
+        a = draw.dot(w) + b
+
+        # Compute forward transformation
+        draw = draw + u * _h(a)
+
+        # Compute log-Jacobian adjustment
+        lja: Scalar = jnp.log(1.0 + u.dot(_dh(a) * w))
+
+        assert len(draw.shape) == 1
+        assert lja.shape == ()
+
+        return draw, lja
+
+    @eqx.filter_jit
+    def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Float[Array, " n_draws"]]:
+        f = jax.vmap(self.__forward_and_adjust, 0)
+        return f(draws)
+
+    def transform_params(self) -> Dict[str, Array]:
+        """
+        Applies the invertibility constraint to u_hat to compute the constrained
+        parameter u, and returns all parameters.
+
+        # Returns
+        A dictionary of parameters including the computed, constrained 'u'.
+        """
+        constrained_params = self.constrain_params()
+
+        # Extract parameters
+        w: Array = constrained_params["w"]
+        u_hat: Array = constrained_params["u_hat"]
+
+        # Compute the constrained u (pins w . u at approximately -1)
+        u = u_hat + (-1.0 - jnp.dot(w, u_hat)) / (jnp.sum(w**2) + 1e-6) * w
+        constrained_params["u"] = u
+
+        return constrained_params
+
+
+class Planar(FlowSpec):
+    """
+    A specification for the Planar flow.
+    """
+    def __init__(self):
+        pass
+
+    def construct(self, dim: int) -> PlanarLayer:
+        return PlanarLayer(dim)
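`transform_params` reparameterizes u_hat so that w . u sits at (approximately) -1; since tanh' takes values in (0, 1], the Jacobian term 1 + h'(a) * (w . u) = 1 - h'(a) stays non-negative, which is what keeps the log-Jacobian adjustment well defined. A quick check of that property, assuming the `PlanarLayer` API shown above:

import jax.numpy as jnp
import jax.random as jr

from bayinx.flows.planar import PlanarLayer

layer = PlanarLayer(dim=4, key=jr.key(7))
params = layer.transform_params()

# w . u should sit at (approximately) -1 by construction
print(jnp.dot(params["w"], params["u"]))  # ~ -1.0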
bayinx/flows/radial.py
ADDED
@@ -0,0 +1 @@
+# WIP
bayinx/flows/sylvester.py
ADDED
@@ -0,0 +1,225 @@
+from typing import Callable, Dict, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from jax.lax import scan
+from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+def _h(x: Array) -> Array:
+    """Non-linearity for the Sylvester flow."""
+    return jnp.tanh(x)
+
+
+def _dh(x: Array) -> Array:
+    """Derivative of the non-linearity."""
+    return 1.0 - jnp.tanh(x)**2
+
+
+def _construct_orthogonal(vectors: Float[Array, "dim rank"]) -> Float[Array, "dim rank"]:
+    """
+    Constructs a D x M orthogonal matrix (Q) using Householder reflections.
+    """
+    D, M = vectors.shape
+
+    # Initialize Q as the thin identity matrix (first M columns of I_D)
+    Q_initial = jnp.eye(D, M)
+
+    def apply_reflection(Q_current, v_k):
+        # Householder reflection H = I - (2 / ||v||^2) v v^T, applied as H @ Q
+        v_norm_sq = jnp.sum(v_k**2)
+        tau = 2.0 / v_norm_sq
+
+        a_T = jnp.dot(v_k, Q_current)
+        update = jnp.outer(v_k, tau * a_T)
+        Q_next = Q_current - update
+
+        return Q_next, None
+
+    Q_final, _ = scan(
+        f=apply_reflection,
+        init=Q_initial,
+        xs=vectors.T
+    )
+
+    return Q_final
+
+
+class SylvesterLayer(FlowLayer):
+    """
+    A Sylvester flow.
+
+    # Attributes
+    - `params`: Dictionary containing raw parameters.
+    - `constraints`: Dictionary of constraining transformations.
+    - `static`: Whether the flow layer is frozen.
+    - `rank`: The rank (M) of the transformation.
+    """
+
+    rank: int
+    params: Dict[str, Array]
+    constraints: Dict[str, Callable[[Array], Array]]
+    static: bool
+
+    def __init__(self, dim: int, rank: int, key: PRNGKeyArray = jr.key(0)):
+        """
+        Initializes the Sylvester flow.
+
+        # Parameters
+        - `dim`: The dimension of the parameter space (D).
+        - `rank`: The number of hidden units/reflections (M).
+        """
+        self.static = False
+        self.rank = rank
+
+        k1, k2, k3, k4 = jr.split(key, 4)
+
+        # Initialize parameters
+        # hvecs: D x M matrix where each column is a Householder vector
+        self.params = {
+            "hvecs": jr.normal(k1, (dim, rank)) / dim**0.5,
+            "r1": jr.normal(k2, (rank, rank)) / rank**0.5,
+            "r2": jr.normal(k3, (rank, rank)) / rank**0.5,
+            "b": jr.normal(k4, (rank,)) / rank**0.5,
+        }
+
+        # Constraint for upper-triangular matrices with positive diagonal
+        def constrain_triangular(matrix: Array):
+            # Extract diagonal and apply exponential to ensure invertibility
+            diag: Array = jnp.exp(jnp.diag(matrix))
+            # Return upper-triangular matrix with modified diagonal
+            return jnp.fill_diagonal(jnp.triu(matrix), diag, inplace=False)
+
+        self.constraints = {
+            "r1": constrain_triangular,
+            "r2": constrain_triangular
+        }
+
+    def transform_params(self) -> Dict[str, Array]:
+        """
+        Applies constraints and constructs the orthogonal matrix Q.
+        """
+        constrained = self.constrain_params()
+
+        # Construct orthogonal matrix
+        q = _construct_orthogonal(constrained["hvecs"])
+        constrained["q"] = q
+
+        return constrained
+
+    def __forward(self, draw: Float[Array, " n_dim"]) -> Float[Array, " n_dim"]:
+        params = self.transform_params()
+
+        # Extract parameters
+        Q: Float[Array, "dim rank"] = params["q"]
+        R1: Float[Array, "rank rank"] = params["r1"]
+        R2: Float[Array, "rank rank"] = params["r2"]
+        b: Float[Array, " rank"] = params["b"]
+
+        # Compute inner terms
+        y = R2.dot(Q.T.dot(draw)) + b
+        h_y = _h(y)
+
+        # Compute forward transform
+        draw = draw + Q.dot(R1.dot(h_y))
+
+        return draw
+
+    @eqx.filter_jit
+    def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+        f = jax.vmap(self.__forward, 0)
+        return f(draws)
+
+    def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:
+        params = self.transform_params()
+
+        # Extract parameters
+        Q: Float[Array, "dim rank"] = params["q"]
+        R1: Float[Array, "rank rank"] = params["r1"]
+        R2: Float[Array, "rank rank"] = params["r2"]
+        b: Float[Array, " rank"] = params["b"]
+
+        # Recompute the argument to the nonlinearity
+        term = R2.dot(Q.T.dot(draw)) + b
+        diag_h_prime = _dh(term)
+
+        # Diagonals of R1 and R2
+        d_r1 = jnp.diag(R1)
+        d_r2 = jnp.diag(R2)
+
+        # Diagonal of the (upper-triangular) matrix I + diag(h') R2 R1:
+        # diag_term_i = 1 + h'_i * R2_ii * R1_ii
+        diag_term = 1.0 + diag_h_prime * d_r2 * d_r1
+
+        # Log-determinant is the sum of the log-diagonal
+        lja = jnp.sum(jnp.log(diag_term))
+
+        assert lja.shape == ()
+
+        return lja
+
+    @eqx.filter_jit
+    def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, " n_draws"]:
+        f = jax.vmap(self.__adjust, 0)
+        return f(draws)
+
+    def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+        params = self.transform_params()
+
+        Q: Float[Array, "dim rank"] = params["q"]
+        R1: Float[Array, "rank rank"] = params["r1"]
+        R2: Float[Array, "rank rank"] = params["r2"]
+        b: Float[Array, " rank"] = params["b"]
+
+        # --- Forward ---
+        q_z = jnp.dot(Q.T, draw)
+        arg = jnp.dot(R2, q_z) + b
+
+        # h(arg)
+        h_y = _h(arg)
+
+        # Update
+        term = jnp.dot(R1, h_y)
+        draw_new = draw + jnp.dot(Q, term)
+
+        # --- Adjust ---
+        # Derivative h'(arg)
+        diag_h_prime = _dh(arg)
+
+        # Diagonals of triangular matrices
+        d_r1 = jnp.diag(R1)
+        d_r2 = jnp.diag(R2)
+
+        # Log-det via the diagonal of the triangular matrix I + diag(h') R2 R1
+        diag_term = 1.0 + diag_h_prime * d_r2 * d_r1
+        lja = jnp.sum(jnp.log(jnp.abs(diag_term)))
+
+        assert len(draw_new.shape) == 1
+        assert lja.shape == ()
+
+        return draw_new, lja
+
+    @eqx.filter_jit
+    def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Float[Array, " n_draws"]]:
+        f = jax.vmap(self.__forward_and_adjust, 0)
+        return f(draws)
+
+
+class Sylvester(FlowSpec):
+    """
+    A specification for the Orthogonal Sylvester flow.
+    """
+    rank: int
+
+    def __init__(self, rank: int):
+        self.rank = rank
+
+    def construct(self, dim: int) -> SylvesterLayer:
+        if self.rank > dim:
+            raise ValueError(f"Rank {self.rank} cannot be greater than dimension {dim}.")
+
+        return SylvesterLayer(dim, self.rank)
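The log-determinant above relies on Sylvester's determinant identity: det(I_D + Q R1 diag(h') R2 Q^T) = det(I_M + diag(h') R2 R1) because Q^T Q = I_M, and the right-hand matrix is upper triangular, so the determinant reduces to the product of the diagonal terms 1 + h'_i * R2_ii * R1_ii. A sketch verifying that the Householder construction yields orthonormal columns, assuming the module-level helper `_construct_orthogonal` defined above is importable:

import jax.numpy as jnp
import jax.random as jr

from bayinx.flows.sylvester import _construct_orthogonal

dim, rank = 6, 3
vectors = jr.normal(jr.key(3), (dim, rank))
Q = _construct_orthogonal(vectors)

# Products of Householder reflections applied to the thin identity
# keep the columns orthonormal: Q^T Q = I_M
print(jnp.allclose(Q.T @ Q, jnp.eye(rank), atol=1e-5))  # True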
bayinx/nodes/__init__.py
ADDED