bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/dists/posnormal.py DELETED
@@ -1,260 +0,0 @@
1
- import jax.numpy as jnp
2
- from jaxtyping import Array, ArrayLike, Float
3
-
4
- from bayinx.dists import normal
5
-
6
-
7
- def prob(
8
- x: Float[ArrayLike, "..."],
9
- mu: Float[ArrayLike, "..."],
10
- sigma: Float[ArrayLike, "..."],
11
- ) -> Float[Array, "..."]:
12
- """
13
- The probability density function (PDF) for a positive Normal distribution.
14
-
15
- # Parameters
16
- - `x`: Where to evaluate the PDF.
17
- - `mu`: The mean.
18
- - `sigma`: The standard deviation.
19
-
20
- # Returns
21
- The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
22
- """
23
- # Cast to Array
24
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
25
-
26
- # Construct boolean mask for non-negative elements
27
- non_negative: Array = jnp.asarray(0.0) <= x
28
-
29
- # Evaluate PDF
30
- evals = jnp.where(
31
- non_negative,
32
- normal.prob(x, mu, sigma) / normal.cdf(mu / sigma, 0.0, 1.0),
33
- jnp.asarray(0.0),
34
- )
35
-
36
- return evals
37
-
38
-
39
- def logprob(
40
- x: Float[ArrayLike, "..."],
41
- mu: Float[ArrayLike, "..."],
42
- sigma: Float[ArrayLike, "..."],
43
- ) -> Float[Array, "..."]:
44
- """
45
- The log of the probability density function (log PDF) for a positive Normal distribution.
46
-
47
- # Parameters
48
- - `x`: Where to evaluate the log PDF.
49
- - `mu`: The mean.
50
- - `sigma`: The standard deviation.
51
-
52
- # Returns
53
- The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
54
- """
55
- # Cast to Array
56
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
57
-
58
- # Construct boolean mask for non-negative elements
59
- non_negative: Array = jnp.asarray(0.0) <= x
60
-
61
- # Evaluate log PDF
62
- evals = jnp.where(
63
- non_negative,
64
- normal.logprob(x, mu, sigma) - normal.logcdf(mu / sigma, 0.0, 1.0),
65
- -jnp.inf,
66
- )
67
-
68
- return evals
69
-
70
-
71
- def uprob(
72
- x: Float[ArrayLike, "..."],
73
- mu: Float[ArrayLike, "..."],
74
- sigma: Float[ArrayLike, "..."],
75
- ) -> Float[Array, "..."]:
76
- """
77
- The unnormalized probability density function (uPDF) for a positive Normal distribution.
78
-
79
- # Parameters
80
- - `x`: Where to evaluate the uPDF.
81
- - `mu`: The mean.
82
- - `sigma`: The standard deviation.
83
-
84
- # Returns
85
- The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
86
- """
87
- # Cast to Array
88
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
89
-
90
- # Construct boolean mask for non-negative elements
91
- non_negative: Array = jnp.asarray(0.0) <= x
92
-
93
- # Evaluate PDF
94
- evals = jnp.where(non_negative, normal.prob(x, mu, sigma), jnp.asarray(0.0))
95
-
96
- return evals
97
-
98
-
99
- def ulogprob(
100
- x: Float[ArrayLike, "..."],
101
- mu: Float[ArrayLike, "..."],
102
- sigma: Float[ArrayLike, "..."],
103
- ) -> Float[Array, "..."]:
104
- """
105
- The log of the unnormalized probability density function (log uPDF) for a positive Normal distribution.
106
-
107
- # Parameters
108
- - `x`: Where to evaluate the log uPDF.
109
- - `mu`: The mean.
110
- - `sigma`: The standard deviation.
111
-
112
- # Returns
113
- The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
114
- """
115
- # Cast to Array
116
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
117
-
118
- # Construct boolean mask for non-negative elements
119
- non_negative: Array = jnp.asarray(0.0) <= x
120
-
121
- # Evaluate log PDF
122
- evals = jnp.where(non_negative, normal.logprob(x, mu, sigma), -jnp.inf)
123
-
124
- return evals
125
-
126
-
127
- def cdf(
128
- x: Float[ArrayLike, "..."],
129
- mu: Float[ArrayLike, "..."],
130
- sigma: Float[ArrayLike, "..."],
131
- ) -> Float[Array, "..."]:
132
- """
133
- The cumulative density function (CDF) for a positive Normal distribution.
134
-
135
- # Parameters
136
- - `x`: Where to evaluate the CDF.
137
- - `mu`: The mean.
138
- - `sigma`: The standard deviation.
139
-
140
- # Returns
141
- The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
142
-
143
- # Notes
144
- Not numerically stable for small `x`.
145
- """
146
- # Cast to Array
147
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
148
-
149
- # Construct boolean mask for non-negative elements
150
- non_negative: Array = jnp.asarray(0.0) <= x
151
-
152
- # Compute intermediates
153
- A: Array = normal.cdf(x, mu, sigma)
154
- B: Array = normal.cdf(-mu / sigma, 0.0, 1.0)
155
- C: Array = normal.cdf(mu / sigma, 0.0, 1.0)
156
-
157
- # Evaluate CDF
158
- evals = jnp.where(non_negative, (A - B) / C, jnp.asarray(0.0))
159
-
160
- return evals
161
-
162
-
163
- # TODO: make numerically stable
164
- def logcdf(
165
- x: Float[ArrayLike, "..."],
166
- mu: Float[ArrayLike, "..."],
167
- sigma: Float[ArrayLike, "..."],
168
- ) -> Float[Array, "..."]:
169
- """
170
- The log-transformed cumulative density function (log CDF) for a positive Normal distribution.
171
-
172
- # Parameters
173
- - `x`: Where to evaluate the log CDF.
174
- - `mu`: The mean.
175
- - `sigma`: The standard deviation.
176
-
177
- # Returns
178
- The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
179
-
180
- # Notes
181
- Not numerically stable for small `x`.
182
- """
183
- # Cast to Array
184
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
185
-
186
- # Construct boolean mask for non-negative elements
187
- non_negative: Array = jnp.asarray(0.0) <= x
188
-
189
- A: Array = normal.logcdf(x, mu, sigma)
190
- B: Array = normal.logcdf(-mu / sigma, 0.0, 1.0)
191
- C: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
192
-
193
- # Evaluate log CDF
194
- evals = jnp.where(non_negative, A + jnp.log1p(-jnp.exp(B - A)) - C, -jnp.inf)
195
-
196
- return evals
197
-
198
-
199
- def ccdf(
200
- x: Float[ArrayLike, "..."],
201
- mu: Float[ArrayLike, "..."],
202
- sigma: Float[ArrayLike, "..."],
203
- ) -> Float[Array, "..."]:
204
- """
205
- The complementary cumulative density function (cCDF) for a positive Normal distribution.
206
-
207
- # Parameters
208
- - `x`: Where to evaluate the cCDF.
209
- - `mu`: The mean.
210
- - `sigma`: The standard deviation.
211
-
212
- # Returns
213
- The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
214
- """
215
- # Cast to arrays
216
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
217
-
218
- # Construct boolean mask for non-negative elements
219
- non_negative: Array = 0.0 <= x
220
-
221
- # Compute intermediates
222
- A: Array = normal.cdf(-x, -mu, sigma)
223
- B: Array = normal.cdf(mu / sigma, 0.0, 1.0)
224
-
225
- # Evaluate cCDF
226
- evals = jnp.where(non_negative, A / B, jnp.asarray(1.0))
227
-
228
- return evals
229
-
230
-
231
- def logccdf(
232
- x: Float[ArrayLike, "..."],
233
- mu: Float[ArrayLike, "..."],
234
- sigma: Float[ArrayLike, "..."],
235
- ) -> Float[Array, "..."]:
236
- """
237
- The log-transformed complementary cumulative density function (log cCDF) for a positive Normal distribution.
238
-
239
- # Parameters
240
- - `x`: Where to evaluate the log cCDF.
241
- - `mu`: The mean.
242
- - `sigma`: The standard deviation.
243
-
244
- # Returns
245
- The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
246
- """
247
- # Cast to arrays
248
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
249
-
250
- # Construct boolean mask for non-negative elements
251
- non_negative: Array = 0.0 <= x
252
-
253
- # Compute intermediates
254
- A: Array = normal.logcdf(-x, -mu, sigma)
255
- B: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
256
-
257
- # Evaluate log cCDF
258
- evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
259
-
260
- return evals
bayinx/dists/uniform.py DELETED
@@ -1,75 +0,0 @@
1
- import jax.lax as _lax
2
- import jax.numpy as jnp
3
- from jaxtyping import Array, ArrayLike, Float
4
-
5
-
6
- def prob(
7
- x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
8
- ) -> Float[Array, "..."]:
9
- """
10
- The probability density function (PDF) for a Uniform distribution.
11
-
12
- # Parameters
13
- - `x`: Value(s) at which to evaluate the PDF.
14
- - `lb`: The lower bound parameter(s).
15
- - `ub`: The upper bound parameter(s).
16
-
17
- # Returns
18
- The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
19
- """
20
-
21
- return 1.0 / (ub - lb) # pyright: ignore
22
-
23
-
24
- def logprob(
25
- x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
26
- ) -> Float[Array, "..."]:
27
- """
28
- The log of the probability density function (log PDF) for a Uniform distribution.
29
-
30
- # Parameters
31
- - `x`: Value(s) at which to evaluate the PDF.
32
- - `lb`: The lower bound parameter(s).
33
- - `ub`: The upper bound parameter(s).
34
-
35
- # Returns
36
- The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
37
- """
38
-
39
- return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
40
-
41
-
42
- def uprob(
43
- x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
44
- ) -> Float[Array, "..."]:
45
- """
46
- The unnormalized probability density function (uPDF) for a Uniform distribution.
47
-
48
- # Parameters
49
- - `x`: Value(s) at which to evaluate the PDF.
50
- - `lb`: The lower bound parameter(s).
51
- - `ub`: The upper bound parameter(s).
52
-
53
- # Returns
54
- The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
55
- """
56
-
57
- return jnp.ones(jnp.broadcast_arrays(x, lb, ub))
58
-
59
-
60
- def ulogprob(
61
- x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
62
- ) -> Float[Array, "..."]:
63
- """
64
- The log of the unnormalized probability density function (log uPDF) for a Uniform distribution.
65
-
66
- # Parameters
67
- - `x`: Value(s) at which to evaluate the PDF.
68
- - `lb`: The lower bound parameter(s).
69
- - `ub`: The upper bound parameter(s).
70
-
71
- # Returns
72
- The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
73
- """
74
-
75
- return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
bayinx/mhx/__init__.py DELETED
@@ -1 +0,0 @@
1
-
bayinx/mhx/vi/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from bayinx.mhx.vi.meanfield import MeanField
2
- from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
3
- from bayinx.mhx.vi.standard import Standard
4
-
5
- __all__ = ['MeanField', 'NormalizingFlow', 'Standard']
@@ -1,3 +0,0 @@
1
- from bayinx.mhx.vi.flows.fullaffine import FullAffine as FullAffine
2
- from bayinx.mhx.vi.flows.planar import Planar as Planar
3
- from bayinx.mhx.vi.flows.radial import Radial as Radial
@@ -1,75 +0,0 @@
1
- from functools import partial
2
- from typing import Tuple
3
-
4
- import equinox as eqx
5
- import jax
6
- import jax.numpy as jnp
7
- from jaxtyping import Array, Scalar
8
-
9
- from bayinx.core import Flow
10
-
11
-
12
- class FullAffine(Flow):
13
- """
14
- A full affine flow.
15
-
16
- # Attributes
17
- - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
18
- - `constraints`: A dictionary of constraining transformations.
19
- """
20
-
21
- def __init__(self, dim: int):
22
- """
23
- Initializes a full affine flow.
24
-
25
- # Parameters
26
- - `dim`: The dimension of the parameter space.
27
- """
28
- self.params = {
29
- "shift": jnp.zeros(dim),
30
- "scale": jnp.zeros((dim, dim)),
31
- }
32
-
33
- if dim == 1:
34
- self.constraints = {}
35
- else:
36
-
37
- @eqx.filter_jit
38
- def constrain_scale(scale: Array):
39
- # Extract diagonal and apply exponential
40
- diag: Array = jnp.exp(jnp.diag(scale))
41
-
42
- # Return matrix with modified diagonal
43
- return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)
44
-
45
- self.constraints = {"scale": constrain_scale}
46
-
47
- @eqx.filter_jit
48
- def forward(self, draws: Array) -> Array:
49
- params = self.transform_params()
50
-
51
- # Extract parameters
52
- shift: Array = params["shift"]
53
- scale: Array = params["scale"]
54
-
55
- # Compute forward transformation
56
- draws = draws @ scale + shift
57
-
58
- return draws
59
-
60
- @eqx.filter_jit
61
- @partial(jax.vmap, in_axes=(None, 0))
62
- def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
63
- params = self.transform_params()
64
-
65
- # Extract parameters
66
- shift: Array = params["shift"]
67
- scale: Array = params["scale"]
68
-
69
- # Compute forward transformation
70
- draws = draws @ scale + shift
71
-
72
- # Compute laj
73
- laj: Scalar = jnp.log(jnp.diag(scale)).sum()
74
-
75
- return draws, laj
@@ -1,74 +0,0 @@
1
- from functools import partial
2
- from typing import Callable, Dict, Tuple
3
-
4
- import equinox as eqx
5
- import jax
6
- import jax.numpy as jnp
7
- import jax.random as jr
8
- from jaxtyping import Array, Float, Scalar
9
-
10
- from bayinx.core import Flow
11
-
12
-
13
- class Planar(Flow):
14
- """
15
- A planar flow.
16
-
17
- # Attributes
18
- - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
19
- - `constraints`: A dictionary of constraining transformations.
20
- """
21
-
22
- params: Dict[str, Float[Array, "..."]]
23
- constraints: Dict[str, Callable[[Array], Array]]
24
-
25
- def __init__(self, dim: int, key=jr.PRNGKey(0)):
26
- """
27
- Initializes a planar flow.
28
-
29
- # Parameters
30
- - `dim`: The dimension of the parameter space.
31
- """
32
- self.params = {
33
- "u": jnp.zeros(dim),
34
- "w": jnp.ones(dim),
35
- "b": jnp.zeros(1),
36
- }
37
- self.constraints = {}
38
-
39
- @eqx.filter_jit
40
- @partial(jax.vmap, in_axes=(None, 0))
41
- def forward(self, draws: Array) -> Array:
42
- params = self.transform_params()
43
-
44
- # Extract parameters
45
- w: Array = params["w"]
46
- u: Array = params["u"]
47
- b: Array = params["b"]
48
-
49
- # Compute forward transformation
50
- draws = draws + u * jnp.tanh(draws.dot(w) + b)
51
-
52
- return draws
53
-
54
- @eqx.filter_jit
55
- @partial(jax.vmap, in_axes=(None, 0))
56
- def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
57
- params = self.transform_params()
58
-
59
- # Extract parameters
60
- w: Array = params["w"]
61
- u: Array = params["u"]
62
- b: Array = params["b"]
63
-
64
- # Compute shared intermediates
65
- x: Array = draws.dot(w) + b
66
-
67
- # Compute forward transformation
68
- draws = draws + u * jnp.tanh(x)
69
-
70
- # Compute laj
71
- h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
72
- laj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
73
-
74
- return draws, laj
@@ -1,94 +0,0 @@
1
- from functools import partial
2
- from typing import Callable, Dict, Tuple
3
-
4
- import equinox as eqx
5
- import jax
6
- import jax.numpy as jnp
7
- import jax.random as jr
8
- from jax.numpy.linalg import norm
9
- from jaxtyping import Array, Float, Scalar
10
-
11
- from bayinx.core import Flow
12
-
13
-
14
- class Radial(Flow):
15
- """
16
- A radial flow.
17
-
18
- # Attributes
19
- - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
20
- - `constraints`: A dictionary of constraining transformations.
21
- """
22
-
23
- params: Dict[str, Float[Array, "..."]]
24
- constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
25
-
26
- def __init__(self, dim: int, key=jr.PRNGKey(0)):
27
- """
28
- Initializes a planar flow.
29
-
30
- # Parameters
31
- - `dim`: The dimension of the parameter space.
32
- """
33
- self.params = {
34
- "alpha": jnp.array(1.0),
35
- "beta": jnp.array(1.0),
36
- "center": jnp.ones(dim),
37
- }
38
- self.constraints = {"beta": jnp.exp}
39
-
40
- @partial(jax.vmap, in_axes=(None, 0))
41
- @eqx.filter_jit
42
- def forward(self, draws: Array) -> Array:
43
- """
44
- Applies the forward radial transformation for each draw.
45
-
46
- # Parameters
47
- - `draws`: Draws from some layer of a normalizing flow.
48
-
49
- # Returns
50
- The transformed samples.
51
- """
52
- params = self.transform_params()
53
-
54
- # Extract parameters
55
- alpha = params["alpha"]
56
- beta = params["beta"]
57
- center = params["center"]
58
-
59
- # Compute distance to center per-draw
60
- r: Array = norm(draws - center)
61
-
62
- # Apply forward transformation
63
- draws = draws + (beta / (alpha + r)) * (draws - center)
64
-
65
- return draws
66
-
67
- @partial(jax.vmap, in_axes=(None, 0))
68
- @eqx.filter_jit
69
- def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
70
- params = self.transform_params()
71
-
72
- # Extract parameters
73
- alpha = params["alpha"]
74
- beta = params["beta"]
75
- center = params["center"]
76
-
77
- # Compute distance to center per-draw
78
- r: Array = norm(draws - center)
79
-
80
- # Compute shared intermediates
81
- x: Array = beta / (alpha + r)
82
-
83
- # Apply forward transformation
84
- draws = draws + (x) * (draws - center)
85
-
86
- # Compute density adjustment
87
- laj = jnp.log(
88
- jnp.abs(
89
- (1.0 + alpha * beta / (alpha + r) ** 2.0)
90
- * (1.0 + x) ** (center.size - 1.0)
91
- )
92
- )
93
-
94
- return draws, laj
@@ -1,19 +0,0 @@
1
- from typing import Callable, Dict
2
-
3
- from jaxtyping import Array, Float
4
-
5
- from bayinx.core import Flow
6
-
7
-
8
- # TODO
9
- class Sylvester(Flow):
10
- """
11
- A sylvester flow.
12
-
13
- # Attributes
14
- - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
15
- - `constraints`: A dictionary of constraining transformations.
16
- """
17
-
18
- params: Dict[str, Float[Array, "..."]]
19
- constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]