bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/flows/fullaffine.py ADDED
@@ -0,0 +1,123 @@
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+ from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+ class FullAffineLayer(FlowLayer):
+     """
+     A full affine flow.
+
+     # Attributes
+     - `params`: The parameters of the full affine flow.
+     - `constraints`: The constraining transformations for the parameters of the full affine flow.
+     - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
+     """
+
+     params: Dict[str, Array]
+     constraints: Dict[str, Callable[[Array], Array]]
+     static: bool
+
+
+     def __init__(self, dim: int, key: PRNGKeyArray):
+         """
+         Initializes a full affine flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         """
+         self.static = False
+
+         # Split key
+         k1, k2 = jr.split(key)
+
+         # Initialize parameters
+         self.params = {
+             "shift": jr.normal(k1, (dim,)) / dim**0.5,
+             "scale": jr.normal(k2, (dim, dim)) / dim**0.5,
+         }
+
+         # Define constraints
+         if dim == 1:
+             self.constraints = {"scale": jnp.exp}
+         else:
+             def constrain_scale(scale: Array):
+                 # Extract diagonal and apply exponential
+                 diag: Array = jnp.exp(jnp.diag(scale))
+
+                 # Return matrix with modified diagonal
+                 return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)
+
+             self.constraints = {"scale": constrain_scale}
+
+     @eqx.filter_jit
+     def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         params = self.transform_params()
+
+         # Extract parameters
+         shift: Float[Array, " n_dim"] = params["shift"]
+         scale: Float[Array, "n_dim n_dim"] = params["scale"]
+
+         # Compute forward transformation
+         draws = (scale @ draws.T).T + shift
+
+         return draws
+
+     def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:
+         params = self.transform_params()
+
+         # Extract parameters
+         scale: Float[Array, "n_dim n_dim"] = params["scale"]
+
+         # Compute log-Jacobian adjustment
+         lja: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+         assert lja.shape == ()
+
+         return lja
+
+     @eqx.filter_jit
+     def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__adjust, 0)
+         return f(draws)
+
+     def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+         params = self.transform_params()
+
+         assert len(draw.shape) == 1
+
+         # Extract parameters
+         shift: Float[Array, " n_dim"] = params["shift"]
+         scale: Float[Array, "n_dim n_dim"] = params["scale"]
+
+         # Compute forward transformation
+         draw = (scale @ draw.T).T + shift
+
+         assert len(draw.shape) == 1
+
+         # Compute lja
+         lja: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+         assert lja.shape == ()
+
+         return draw, lja
+
+     @eqx.filter_jit
+     def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Scalar]:
+         f = jax.vmap(self.__forward_and_adjust, 0)
+         return f(draws)
+
+
+ class FullAffine(FlowSpec):
+     key: PRNGKeyArray
+
+     def __init__(self, key: PRNGKeyArray = jr.key(0)):
+         self.key = key
+
+     def construct(self, dim: int) -> FullAffineLayer:
+         return FullAffineLayer(dim, self.key)
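
Editor's note: the log-Jacobian adjustment above works because `constrain_scale` keeps the scale matrix lower triangular with an exponentiated (hence positive) diagonal, so log|det| of z -> scale @ z + shift is just the sum of the log-diagonal. A minimal standalone sketch of that identity (illustration only, not part of the package; it re-creates the constraint locally rather than importing bayinx):

import jax.numpy as jnp
import jax.random as jr

dim = 4
raw = jr.normal(jr.key(0), (dim, dim)) / dim**0.5

# Mirror of `constrain_scale`: lower triangular with exp'd diagonal.
scale = jnp.fill_diagonal(jnp.tril(raw), jnp.exp(jnp.diag(raw)), inplace=False)

lja = jnp.log(jnp.diag(scale)).sum()      # adjustment used by the layer
exact = jnp.linalg.slogdet(scale)[1]      # exact log|det| of the Jacobian
assert jnp.allclose(lja, exact, atol=1e-5)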
bayinx/flows/lowrankaffine.py ADDED
@@ -0,0 +1,165 @@
+ from typing import Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jax.lax import scan
+ from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+ from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+ def _mat_op(
+     carry: Float[Array, " rank"],
+     vars: Tuple[
+         Float[Array, " rank"],
+         Float[Array, " rank"],
+         Scalar,
+         Scalar,
+         Scalar
+     ]
+ ) -> Tuple[Float[Array, " rank"], Scalar]:
+     """
+     The implicit matrix operations compute:
+
+     x = Lz + b = (D + W)z + b = Dz + (UV^T * M)z + b
+
+     Where W = UV^T * M is the low-rank representation of the strictly lower triangular matrix W,
+     parameterized by U and V, and an implicit mask M that zeros out upper triangular elements.
+
+     In the implementation only U, V, diag(D), and b are needed to complete the matrix operations.
+     """
+     # Unpack state
+     U_r, V_r, z_i, d_i, b_i = vars
+     h_i = carry
+
+     # Compute partial product
+     h_next = h_i + z_i * V_r
+
+     # Compute i-th element of transformed draw:
+     x_i = b_i + z_i * d_i + U_r.dot(h_i)
+
+     return h_next, x_i
+
+ class LowRankAffineLayer(FlowLayer):
+     """
+     A low-rank affine flow.
+
+     # Attributes
+     - `params`: A dictionary containing the shift and low-rank representation of the scale parameters.
+     - `constraints`: A dictionary of constraining transformations.
+     - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
+     - `rank`: Rank of the scale transformation.
+     """
+
+     rank: int
+
+     def __init__(self, dim: int, rank: int, key: PRNGKeyArray = jr.key(0)):
+         """
+         Initializes a low-rank affine flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         - `rank`: The rank of the (implicit) scale matrix.
+         """
+         self.static = False
+         self.rank = rank
+
+         # Split key
+         k1, k2, k3, k4 = jr.split(key, 4)
+
+         # Initialize parameters
+         self.params = {
+             "shift": jr.normal(k1, (dim,)) / dim**0.5,
+             "diag_scale": jr.normal(k2, (dim,)) / dim**0.5,
+             "offdiag_scale": (
+                 jr.normal(k3, (dim, rank)) / dim**0.5,
+                 jr.normal(k4, (dim, rank)) / dim**0.5
+             )
+         }
+
+         # Define constraints
+         self.constraints = {"diag_scale": jnp.exp}
+
+
+     def __forward(self, draw: Float[Array, " n_dim"]) -> Float[Array, " n_dim"]:
+         params = self.transform_params()
+
+         # Extract parameters
+         shift: Array = params["shift"]
+         diag: Array = params["diag_scale"]
+         U, V = params["offdiag_scale"]
+
+         # Compute forward transformation
+         _, draw = scan(
+             f=_mat_op,
+             init=jnp.zeros((self.rank,)),
+             xs=(U, V, draw, diag, shift)
+         )
+
+         return draw
+
+     @eqx.filter_jit
+     def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__forward, 0)
+         return f(draws)
+
+     def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:
+         params = self.transform_params()
+
+         diag: Array = params["diag_scale"]
+
+         # Compute log-Jacobian adjustment
+         lja: Scalar = jnp.log(diag).sum()
+
+         assert lja.shape == ()
+
+         return lja
+
+     @eqx.filter_jit
+     def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__adjust, 0)
+         return f(draws)
+
+     def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+         params = self.transform_params()
+
+         assert len(draw.shape) == 1
+
+         # Extract parameters
+         shift: Array = params["shift"]
+         diag: Array = params["diag_scale"]
+         U, V = params["offdiag_scale"]
+
+         # Compute log-Jacobian adjustment
+         lja: Scalar = jnp.log(diag).sum()
+
+         # Compute forward transformation
+         _, draw = scan(
+             f=_mat_op,
+             init=jnp.zeros((self.rank,)),
+             xs=(U, V, draw, diag, shift)
+         )
+
+         assert len(draw.shape) == 1
+         assert lja.shape == ()
+
+         return draw, lja
+
+     @eqx.filter_jit
+     def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Scalar]:
+         f = jax.vmap(self.__forward_and_adjust, 0)
+         return f(draws)
+
+ class LowRankAffine(FlowSpec):
+     rank: int
+
+     def __init__(self, rank: int):
+         self.rank = rank
+
+     def construct(self, dim: int) -> LowRankAffineLayer:
+         if (self.rank > (dim - 1)/2):
+             raise ValueError(f"Rank {self.rank} is large, consider using a full affine flow instead.")
+
+         return LowRankAffineLayer(dim, self.rank)
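
Editor's note: the `_mat_op` scan above is an implicit matrix-vector product with L = diag(d) + tril(U V^T, k=-1), so the transform is x = L z + b and the log-Jacobian adjustment is sum(log d). A standalone sketch checking both claims against the dense matrix (illustration only; the recursion is restated locally rather than imported):

import jax.numpy as jnp
import jax.random as jr
from jax.lax import scan

dim, rank = 6, 2
k1, k2, k3, k4, k5 = jr.split(jr.key(0), 5)
U = jr.normal(k1, (dim, rank))
V = jr.normal(k2, (dim, rank))
d = jnp.exp(jr.normal(k3, (dim,)))   # positive, as `diag_scale` is after its exp constraint
b = jr.normal(k4, (dim,))
z = jr.normal(k5, (dim,))

def mat_op(carry, xs):               # same recursion as `_mat_op`
    U_r, V_r, z_i, d_i, b_i = xs
    return carry + z_i * V_r, b_i + z_i * d_i + U_r.dot(carry)

_, x_scan = scan(mat_op, jnp.zeros((rank,)), (U, V, z, d, b))

L = jnp.diag(d) + jnp.tril(U @ V.T, k=-1)   # the matrix the scan never materializes
assert jnp.allclose(x_scan, L @ z + b, atol=1e-5)
assert jnp.allclose(jnp.log(d).sum(), jnp.linalg.slogdet(L)[1], atol=1e-5)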
bayinx/flows/planar.py ADDED
@@ -0,0 +1,155 @@
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+ from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+ def _h(x: Array) -> Array:
+     return jnp.tanh(x)
+
+ def _dh(x: Array) -> Array:
+     return 1.0 - jnp.tanh(x)**2
+
+
+ class PlanarLayer(FlowLayer):
+     """
+     A Planar flow.
+
+     # Attributes
+     - `params`: The parameters of the Planar flow (w, u_hat, b).
+     - `constraints`: The constraining transformations for the parameters.
+     - `static`: Whether the flow layer is frozen.
+     """
+
+     params: Dict[str, Array]
+     constraints: Dict[str, Callable[[Array], Array]]
+     static: bool
+
+     def __init__(self, dim: int, key: PRNGKeyArray = jr.key(0)):
+         """
+         Initializes a Planar flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         """
+         self.static = False
+         # Split key
+         k1, k2, k3 = jr.split(key, 3)
+
+         # Initialize parameters
+         self.params = {
+             "w": jr.normal(k1, (dim,)) / dim**0.5,
+             "u_hat": jr.normal(k2, (dim,)) / dim**0.5,
+             "b": jr.normal(k3, ()) / dim**0.5,
+         }
+
+         self.constraints = {}
+
+     def __forward(self, draw: Float[Array, " n_dim"]) -> Float[Array, " n_dim"]:
+         params = self.transform_params()
+
+         # Extract parameters
+         w: Float[Array, " n_dim"] = params["w"]
+         u: Float[Array, " n_dim"] = params["u"]
+         b: Scalar = params["b"]
+
+         # Compute inner term
+         a = draw.dot(w) + b
+
+         # Compute nonlinear stretch
+         h = _h(a)
+
+         # Compute forward transformation
+         draw = draw + u * h
+
+         return draw
+
+     @eqx.filter_jit
+     def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__forward, 0)
+         return f(draws)
+
+     def __adjust(self, draw: Float[Array, "n_dim"]) -> Scalar:  # noqa
+         params = self.transform_params()
+
+         # Extract parameters
+         w: Float[Array, " n_dim"] = params["w"]
+         u: Float[Array, " n_dim"] = params["u"]
+         b: Scalar = params["b"]
+
+         # Compute inner term
+         a = draw.dot(w) + b
+
+         # Compute log-Jacobian adjustment
+         lja: Scalar = jnp.log(1.0 + u.dot(_dh(a) * w))
+
+         assert lja.shape == ()
+         return lja
+
+     @eqx.filter_jit
+     def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__adjust, 0)
+         return f(draws)
+
+     def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+         params = self.transform_params()
+
+         # Extract parameters
+         w: Float[Array, " n_dim"] = params["w"]
+         u: Float[Array, " n_dim"] = params["u"]
+         b: Scalar = params["b"]
+
+         # Compute inner term
+         a = draw.dot(w) + b
+
+         # Compute forward transformation
+         draw = draw + u * _h(a)
+
+         # Compute log-Jacobian adjustment
+         lja: Scalar = jnp.log(1.0 + u.dot(_dh(a) * w))
+
+         assert len(draw.shape) == 1
+         assert lja.shape == ()
+
+         return draw, lja
+
+     @eqx.filter_jit
+     def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Scalar]:
+         f = jax.vmap(self.__forward_and_adjust, 0)
+         return f(draws)
+
+
+     def transform_params(self) -> Dict[str, Array]:
+         """
+         Applies the affine constraint to u_hat to compute the constrained parameter u,
+         and returns all parameters.
+
+         # Returns
+         A dictionary of parameters including the computed, constrained 'u'.
+         """
+         constrained_params = self.constrain_params()
+
+         # Extract parameters
+         w: Array = constrained_params["w"]
+         u_hat: Array = constrained_params["u_hat"]
+
+         # Compute the constrained u
+         u = u_hat + (-1.0 - jnp.dot(w, u_hat)) / (jnp.sum(w**2) + 1e-6) * w
+         constrained_params["u"] = u
+
+         return constrained_params
+
+ class Planar(FlowSpec):
+     """
+     A specification for the Planar flow.
+     """
+     def __init__(self):
+         pass
+
+     def construct(self, dim: int) -> PlanarLayer:
+         return PlanarLayer(dim)
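
Editor's note: `transform_params` pins w.u to -1, which keeps 1 + h'(a)(w.u) non-negative for tanh and hence keeps the layer invertible. A standalone sketch (illustration only) confirming the reparameterization and comparing the closed-form adjustment with the exact log-determinant via jax.jacfwd:

import jax
import jax.numpy as jnp
import jax.random as jr

dim = 3
kw, ku, kb, kz = jr.split(jr.key(0), 4)
w = jr.normal(kw, (dim,))
u_hat = jr.normal(ku, (dim,))
b = jr.normal(kb, ())
z = jr.normal(kz, (dim,))

# Same reparameterization as `transform_params`.
u = u_hat + (-1.0 - jnp.dot(w, u_hat)) / (jnp.sum(w**2) + 1e-6) * w
assert jnp.allclose(jnp.dot(w, u), -1.0, atol=1e-4)

planar = lambda z: z + u * jnp.tanh(jnp.dot(z, w) + b)
a = jnp.dot(z, w) + b
lja = jnp.log(1.0 + u.dot((1.0 - jnp.tanh(a)**2) * w))   # closed form used by `__adjust`
exact = jnp.linalg.slogdet(jax.jacfwd(planar)(z))[1]
assert jnp.allclose(lja, exact, atol=1e-3)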
bayinx/flows/radial.py ADDED
@@ -0,0 +1 @@
+ # WIP
bayinx/flows/sylvester.py ADDED
@@ -0,0 +1,225 @@
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jax.lax import scan
+ from jaxtyping import Array, Float, PRNGKeyArray, Scalar
+
+ from bayinx.core.flow import FlowLayer, FlowSpec
+
+
+ def _h(x: Array) -> Array:
+     """Non-linearity for Sylvester flow."""
+     return jnp.tanh(x)
+
+
+ def _dh(x: Array) -> Array:
+     """Derivative of the non-linearity."""
+     return 1.0 - jnp.tanh(x)**2
+
+
+ def _construct_orthogonal(vectors: Float[Array, " dim rank"]) -> Float[Array, " dim rank"]:
+     """
+     Constructs a D x M orthogonal matrix (Q) using Householder reflections.
+     """
+     D, M = vectors.shape
+
+     # Initialize Q as the thin identity matrix (first M columns of I_D)
+     Q_initial = jnp.eye(D, M)
+
+     def apply_reflection(Q_current, v_k):
+         v_norm_sq = jnp.sum(v_k**2)
+         tau = 2.0 / v_norm_sq
+
+         a_T = jnp.dot(v_k, Q_current)
+
+         update = jnp.outer(v_k, tau * a_T)
+
+         Q_next = Q_current - update
+
+         return Q_next, None
+
+     Q_final, _ = scan(
+         f=apply_reflection,
+         init=Q_initial,
+         xs=vectors.T
+     )
+
+     return Q_final
+
+
+ class SylvesterLayer(FlowLayer):
+     """
+     A Sylvester flow.
+
+     # Attributes
+     - `params`: Dictionary containing raw parameters.
+     - `constraints`: Dictionary of constraining transformations.
+     - `static`: Whether the flow layer is frozen.
+     - `rank`: The rank (M) of the transformation.
+     """
+
+     rank: int
+     params: Dict[str, Array]
+     constraints: Dict[str, Callable[[Array], Array]]
+     static: bool
+
+     def __init__(self, dim: int, rank: int, key: PRNGKeyArray = jr.key(0)):
+         """
+         Initializes the Sylvester flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space (D).
+         - `rank`: The number of hidden units/reflections (M).
+         """
+         self.static = False
+         self.rank = rank
+
+         k1, k2, k3, k4 = jr.split(key, 4)
+
+         # Initialize parameters
+         # hvecs: D x M matrix where each column is a Householder vector
+         self.params = {
+             "hvecs": jr.normal(k1, (dim, rank)) / dim**0.5,
+             "r1": jr.normal(k2, (rank, rank)) / rank**0.5,
+             "r2": jr.normal(k3, (rank, rank)) / rank**0.5,
+             "b": jr.normal(k4, (rank,)) / rank**0.5,
+         }
+
+         # Constraint for Upper Triangular matrices with positive diagonal
+         def constrain_triangular(matrix: Array):
+             # Extract diagonal and apply exponential to ensure invertibility
+             diag: Array = jnp.exp(jnp.diag(matrix))
+             # Return upper triangular matrix with modified diagonal
+             return jnp.fill_diagonal(jnp.triu(matrix), diag, inplace=False)
+
+         self.constraints = {
+             "r1": constrain_triangular,
+             "r2": constrain_triangular
+         }
+
+     def transform_params(self) -> Dict[str, Array]:
+         """
+         Applies constraints and constructs the orthogonal matrix Q.
+         """
+         constrained = self.constrain_params()
+
+         # Construct orthogonal matrix
+         q = _construct_orthogonal(constrained["hvecs"])
+         constrained["q"] = q
+
+         return constrained
+
+     def __forward(self, draw: Float[Array, " n_dim"]) -> Float[Array, " n_dim"]:
+         params = self.transform_params()
+
+         # Extract parameters
+         Q: Float[Array, "dim rank"] = params["q"]
+         R1: Float[Array, "rank rank"] = params["r1"]
+         R2: Float[Array, "rank rank"] = params["r2"]
+         b: Float[Array, " rank"] = params["b"]
+
+         # Compute inner terms
+         y = R2.dot(Q.T.dot(draw)) + b
+         h_y = _h(y)
+
+         # Compute forward transform
+         draw = draw + Q.dot(R1.dot(h_y))
+
+         return draw
+
+     @eqx.filter_jit
+     def forward(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__forward, 0)
+         return f(draws)
+
+     def __adjust(self, draw: Float[Array, " n_dim"]) -> Scalar:
+         params = self.transform_params()
+
+         # Extract parameters
+         Q: Float[Array, "dim rank"] = params["q"]
+         R1: Float[Array, "rank rank"] = params["r1"]
+         R2: Float[Array, "rank rank"] = params["r2"]
+         b: Float[Array, " rank"] = params["b"]
+
+         # Recompute the argument to the nonlinearity
+         term = R2.dot(Q.T.dot(draw)) + b
+         diag_h_prime = _dh(term)
+
+         # Diagonal of R1 and R2
+         d_r1 = jnp.diag(R1)
+         d_r2 = jnp.diag(R2)
+
+         # Diagonal term of the matrix (I + diag(h') R2 R1)
+         # diag_term_i = 1 + h'_i * (R2_ii * R1_ii)
+         diag_term = 1.0 + diag_h_prime * d_r2 * d_r1
+
+         lja = jnp.sum(jnp.log(diag_term))
+
+         assert lja.shape == ()
+
+         return lja
+
+     @eqx.filter_jit
+     def adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Float[Array, "n_draws n_dim"]:
+         f = jax.vmap(self.__adjust, 0)
+         return f(draws)
+
+     def __forward_and_adjust(self, draw: Float[Array, " n_dim"]) -> Tuple[Float[Array, " n_dim"], Scalar]:
+         params = self.transform_params()
+
+         Q: Float[Array, "dim rank"] = params["q"]
+         R1: Float[Array, "rank rank"] = params["r1"]
+         R2: Float[Array, "rank rank"] = params["r2"]
+         b: Float[Array, " rank"] = params["b"]
+
+         # --- Forward ---
+         q_z = jnp.dot(Q.T, draw)
+         arg = jnp.dot(R2, q_z) + b
+
+         # h(arg)
+         h_y = _h(arg)
+
+         # Update
+         term = jnp.dot(R1, h_y)
+         draw_new = draw + jnp.dot(Q, term)
+
+         # --- Adjust ---
+         # Derivative h'(arg)
+         diag_h_prime = _dh(arg)
+
+         # Diagonals of triangular matrices
+         d_r1 = jnp.diag(R1)
+         d_r2 = jnp.diag(R2)
+
+         # Log-det of triangular matrix
+         diag_term = 1.0 + diag_h_prime * d_r2 * d_r1
+         lja = jnp.sum(jnp.log(jnp.abs(diag_term)))
+
+         assert len(draw_new.shape) == 1
+         assert lja.shape == ()
+
+         return draw_new, lja
+
+     @eqx.filter_jit
+     def forward_and_adjust(self, draws: Float[Array, "n_draws n_dim"]) -> Tuple[Float[Array, "n_draws n_dim"], Scalar]:
+         f = jax.vmap(self.__forward_and_adjust, 0)
+         return f(draws)
+
+
+ class Sylvester(FlowSpec):
+     """
+     A specification for the Orthogonal Sylvester flow.
+     """
+     rank: int
+
+     def __init__(self, rank: int):
+         self.rank = rank
+
+     def construct(self, dim: int) -> SylvesterLayer:
+         if self.rank > dim:
+             raise ValueError(f"Rank {self.rank} cannot be greater than dimension {dim}.")
+
+         return SylvesterLayer(dim, self.rank)
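
Editor's note: `__adjust` relies on Sylvester's determinant identity. With Q^T Q = I_M and R1, R2 upper triangular, det(I_D + Q R1 diag(h'(arg)) R2 Q^T) = det(I_M + diag(h'(arg)) R2 R1), and the right-hand matrix is triangular, so only the diagonal products R2_ii R1_ii survive. A standalone sketch (illustration only; jnp.linalg.qr stands in for `_construct_orthogonal`) comparing that formula with the exact log-determinant:

import jax
import jax.numpy as jnp
import jax.random as jr

D, M = 5, 2
k1, k2, k3, k4, k5 = jr.split(jr.key(0), 5)

Q, _ = jnp.linalg.qr(jr.normal(k1, (D, M)))           # thin orthogonal Q, Q.T @ Q = I_M

def triu_posdiag(A):                                  # mirror of `constrain_triangular`
    return jnp.fill_diagonal(jnp.triu(A), jnp.exp(jnp.diag(A)), inplace=False)

R1 = triu_posdiag(jr.normal(k2, (M, M)) / M**0.5)
R2 = triu_posdiag(jr.normal(k3, (M, M)) / M**0.5)
b = jr.normal(k4, (M,))
z = jr.normal(k5, (D,))

sylvester = lambda z: z + Q @ (R1 @ jnp.tanh(R2 @ (Q.T @ z) + b))
h_prime = 1.0 - jnp.tanh(R2 @ (Q.T @ z) + b)**2

lja = jnp.sum(jnp.log(1.0 + h_prime * jnp.diag(R2) * jnp.diag(R1)))   # formula used by `__adjust`
exact = jnp.linalg.slogdet(jax.jacfwd(sylvester)(z))[1]
assert jnp.allclose(lja, exact, atol=1e-4)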
bayinx/nodes/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .continuous import Continuous as Continuous
+ from .observed import Observed as Observed
+ from .stochastic import Stochastic as Stochastic