bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/dists/posnormal.py
DELETED
|
@@ -1,260 +0,0 @@
|
|
|
1
|
-
import jax.numpy as jnp
|
|
2
|
-
from jaxtyping import Array, ArrayLike, Float
|
|
3
|
-
|
|
4
|
-
from bayinx.dists import normal
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Probability density function (PDF) of a Normal distribution truncated to the non-negative reals.

    # Parameters
    - `x`: Point(s) at which the PDF is evaluated.
    - `mu`: Mean of the underlying (untruncated) Normal.
    - `sigma`: Standard deviation of the underlying (untruncated) Normal.

    # Returns
    The PDF at `x`, broadcast across the shapes of `x`, `mu`, and `sigma`; zero wherever `x` is negative.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    # Mass that the untruncated Normal places on [0, inf); dividing by it renormalizes the density.
    mass_above_zero = normal.cdf(mu / sigma, 0.0, 1.0)
    density = normal.prob(x, mu, sigma) / mass_above_zero

    in_support: Array = x >= jnp.asarray(0.0)
    return jnp.where(in_support, density, jnp.asarray(0.0))
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Log probability density function (log PDF) of a Normal distribution truncated to the non-negative reals.

    # Parameters
    - `x`: Point(s) at which the log PDF is evaluated.
    - `mu`: Mean of the underlying (untruncated) Normal.
    - `sigma`: Standard deviation of the underlying (untruncated) Normal.

    # Returns
    The log PDF at `x`, broadcast across the shapes of `x`, `mu`, and `sigma`; `-inf` wherever `x` is negative.
    """
    x, mu, sigma = (jnp.asarray(arg) for arg in (x, mu, sigma))

    # Log of the mass the untruncated Normal places on [0, inf): the log normalizing constant.
    log_mass_above_zero = normal.logcdf(mu / sigma, 0.0, 1.0)
    log_density = normal.logprob(x, mu, sigma) - log_mass_above_zero

    return jnp.where(x >= jnp.asarray(0.0), log_density, -jnp.inf)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def uprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Unnormalized probability density function (uPDF) of a Normal distribution truncated to the non-negative reals.

    # Parameters
    - `x`: Point(s) at which the uPDF is evaluated.
    - `mu`: Mean of the underlying (untruncated) Normal.
    - `sigma`: Standard deviation of the underlying (untruncated) Normal.

    # Returns
    The uPDF at `x`, broadcast across the shapes of `x`, `mu`, and `sigma`; zero wherever `x` is negative.
    """
    x, mu, sigma = map(jnp.asarray, (x, mu, sigma))

    # Plain Normal density restricted to the support; the truncation constant is omitted.
    in_support: Array = x >= 0.0
    return jnp.where(in_support, normal.prob(x, mu, sigma), jnp.asarray(0.0))
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def ulogprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Log of the unnormalized probability density function (log uPDF) of a Normal distribution truncated to the non-negative reals.

    # Parameters
    - `x`: Point(s) at which the log uPDF is evaluated.
    - `mu`: Mean of the underlying (untruncated) Normal.
    - `sigma`: Standard deviation of the underlying (untruncated) Normal.

    # Returns
    The log uPDF at `x`, broadcast across the shapes of `x`, `mu`, and `sigma`; `-inf` wherever `x` is negative.
    """
    x, mu, sigma = map(jnp.asarray, (x, mu, sigma))

    # Plain Normal log density restricted to the support; the log truncation constant is omitted.
    in_support: Array = x >= 0.0
    return jnp.where(in_support, normal.logprob(x, mu, sigma), -jnp.inf)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def cdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a positive Normal distribution
    (a Normal(`mu`, `sigma`) truncated to the non-negative reals).

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `mu`: The mean of the underlying (untruncated) Normal.
    - `sigma`: The standard deviation of the underlying (untruncated) Normal.

    # Returns
    The CDF evaluated at `x` (zero wherever `x` is negative). The output will have the
    broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    The probability mass on [0, x] is computed with whichever of two algebraically equal
    differences avoids cancellation for the sign of `mu`. Relative precision still degrades
    as `x` approaches 0, where two nearly equal probabilities are subtracted.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # Construct boolean mask for non-negative elements
    non_negative: Array = jnp.asarray(0.0) <= x

    # Normalizing constant P(Y >= 0) = Phi(mu / sigma) for Y ~ Normal(mu, sigma)
    Z: Array = normal.cdf(mu / sigma, 0.0, 1.0)

    # Two equal expressions for P(0 <= Y <= x):
    #   upper form: Phi((x - mu) / sigma) - Phi(-mu / sigma)
    #   lower form: Phi(mu / sigma)       - Phi((mu - x) / sigma)
    # For mu << 0 both terms of the upper form round to 1 and cancel to 0, so the
    # lower form (a difference of small tail probabilities) is used there instead.
    upper: Array = normal.cdf(x, mu, sigma) - normal.cdf(-mu / sigma, 0.0, 1.0)
    lower: Array = Z - normal.cdf(-x, -mu, sigma)
    mass: Array = jnp.where(mu < 0.0, lower, upper)

    # Evaluate CDF
    evals = jnp.where(non_negative, mass / Z, jnp.asarray(0.0))

    return evals
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
def _log1mexp(t: Array) -> Array:
    """
    Compute `log(1 - exp(t))` accurately for `t <= 0` (Maechler, 2012).

    Uses `log(-expm1(t))` when `t` is close to 0, where `exp(t)` rounds to 1, and
    `log1p(-exp(t))` for more negative `t`, where `expm1(t)` rounds to -1.
    """
    return jnp.where(
        t > -jnp.log(2.0),
        jnp.log(-jnp.expm1(t)),
        jnp.log1p(-jnp.exp(t)),
    )


# NOTE: tail-stable by choosing the log-space difference per sign of `mu`;
# precision still degrades as `x` approaches 0.
def logcdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log-transformed cumulative distribution function (log CDF) for a positive Normal
    distribution (a Normal(`mu`, `sigma`) truncated to the non-negative reals).

    # Parameters
    - `x`: Where to evaluate the log CDF.
    - `mu`: The mean of the underlying (untruncated) Normal.
    - `sigma`: The standard deviation of the underlying (untruncated) Normal.

    # Returns
    The log CDF evaluated at `x` (`-inf` wherever `x` is negative). The output will have the
    broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    Not numerically stable for small `x`: the result is the log of a difference of two
    nearly equal probabilities there.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # Construct boolean mask for non-negative elements
    non_negative: Array = jnp.asarray(0.0) <= x

    # log P(Y >= 0) = log Phi(mu / sigma) for Y ~ Normal(mu, sigma)
    log_Z: Array = normal.logcdf(mu / sigma, 0.0, 1.0)

    # mu >= 0: log(Phi((x - mu)/sigma) - Phi(-mu/sigma)) - log Z
    A: Array = normal.logcdf(x, mu, sigma)
    B: Array = normal.logcdf(-mu / sigma, 0.0, 1.0)
    upper: Array = A + _log1mexp(B - A) - log_Z

    # mu < 0: log(Phi(mu/sigma) - Phi((mu - x)/sigma)) - log Z
    #       = log(1 - exp(log Phi((mu - x)/sigma) - log Z)),
    # which avoids the cancellation of A and B (both ~0) when mu << 0.
    lower: Array = _log1mexp(normal.logcdf(-x, -mu, sigma) - log_Z)

    # Evaluate log CDF
    evals = jnp.where(non_negative, jnp.where(mu < 0.0, lower, upper), -jnp.inf)

    return evals
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
def ccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Complementary cumulative distribution function (cCDF) of a Normal distribution truncated to the non-negative reals.

    # Parameters
    - `x`: Point(s) at which the cCDF is evaluated.
    - `mu`: Mean of the underlying (untruncated) Normal.
    - `sigma`: Standard deviation of the underlying (untruncated) Normal.

    # Returns
    The cCDF at `x`, broadcast across the shapes of `x`, `mu`, and `sigma`; one wherever `x` is negative.
    """
    x, mu, sigma = map(jnp.asarray, (x, mu, sigma))

    # Upper-tail mass P(Y > x) of the untruncated Normal, via the reflected CDF Phi((mu - x) / sigma).
    upper_tail = normal.cdf(-x, -mu, sigma)
    # P(Y >= 0): the truncation's normalizing constant.
    mass_above_zero = normal.cdf(mu / sigma, 0.0, 1.0)

    # Below the support every draw exceeds x, so the cCDF is 1 there.
    return jnp.where(0.0 <= x, upper_tail / mass_above_zero, jnp.asarray(1.0))
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
def logccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    Log of the complementary cumulative distribution function (log cCDF) of a Normal distribution truncated to the non-negative reals.

    # Parameters
    - `x`: Point(s) at which the log cCDF is evaluated.
    - `mu`: Mean of the underlying (untruncated) Normal.
    - `sigma`: Standard deviation of the underlying (untruncated) Normal.

    # Returns
    The log cCDF at `x`, broadcast across the shapes of `x`, `mu`, and `sigma`; zero wherever `x` is negative.
    """
    x, mu, sigma = map(jnp.asarray, (x, mu, sigma))

    # log P(Y > x) for the untruncated Normal, via the reflected CDF Phi((mu - x) / sigma).
    log_upper_tail = normal.logcdf(-x, -mu, sigma)
    # log P(Y >= 0): the log normalizing constant of the truncation.
    log_mass_above_zero = normal.logcdf(mu / sigma, 0.0, 1.0)

    # Below the support the cCDF is 1, i.e. log cCDF = 0.
    return jnp.where(0.0 <= x, log_upper_tail - log_mass_above_zero, jnp.asarray(0.0))
|
bayinx/dists/uniform.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
import jax.lax as _lax
|
|
2
|
-
import jax.numpy as jnp
|
|
3
|
-
from jaxtyping import Array, ArrayLike, Float
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def prob(
    x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The probability density function (PDF) for a Uniform distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the PDF.
    - `lb`: The lower bound parameter(s).
    - `ub`: The upper bound parameter(s).

    # Returns
    The PDF evaluated at `x`: `1 / (ub - lb)` where `lb <= x <= ub` and `0` elsewhere.
    The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
    """
    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

    # The density is non-zero only on the support [lb, ub]; masking on `x` also makes
    # the output broadcast against `x`, which the bare `1 / (ub - lb)` did not.
    in_support: Array = (lb <= x) & (x <= ub)

    return jnp.where(in_support, 1.0 / (ub - lb), jnp.asarray(0.0))
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def logprob(
    x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log of the probability density function (log PDF) for a Uniform distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PDF.
    - `lb`: The lower bound parameter(s).
    - `ub`: The upper bound parameter(s).

    # Returns
    The log of the PDF evaluated at `x`: `-log(ub - lb)` where `lb <= x <= ub` and `-inf`
    elsewhere. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
    """
    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

    # Outside [lb, ub] the density is 0, so the log density is -inf.
    in_support: Array = (lb <= x) & (x <= ub)

    return jnp.where(in_support, -jnp.log(ub - lb), -jnp.inf)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def uprob(
    x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The unnormalized probability density function (uPDF) for a Uniform distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the uPDF.
    - `lb`: The lower bound parameter(s).
    - `ub`: The upper bound parameter(s).

    # Returns
    The uPDF evaluated at `x`: `1` where `lb <= x <= ub` and `0` elsewhere. The output will
    have the broadcasted shapes of `x`, `lb`, and `ub`.
    """
    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

    # The support indicator already has the broadcast shape of all three inputs. The old
    # `jnp.ones(jnp.broadcast_arrays(...))` passed a list of arrays as a shape and raised.
    in_support: Array = (lb <= x) & (x <= ub)

    return jnp.where(in_support, jnp.asarray(1.0), jnp.asarray(0.0))
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def ulogprob(
    x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log of the unnormalized probability density function (log uPDF) for a Uniform distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the log uPDF.
    - `lb`: The lower bound parameter(s).
    - `ub`: The upper bound parameter(s).

    # Returns
    The log uPDF evaluated at `x`: `0` where `lb <= x <= ub` and `-inf` elsewhere. The output
    will have the broadcasted shapes of `x`, `lb`, and `ub`.
    """
    x, lb, ub = jnp.asarray(x), jnp.asarray(lb), jnp.asarray(ub)

    # The support indicator already has the broadcast shape of all three inputs. The old
    # `jnp.zeros(jnp.broadcast_arrays(...))` passed a list of arrays as a shape and raised.
    in_support: Array = (lb <= x) & (x <= ub)

    return jnp.where(in_support, jnp.asarray(0.0), -jnp.inf)
|
bayinx/mhx/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
|
bayinx/mhx/vi/__init__.py
DELETED
bayinx/mhx/vi/flows/__init__.py
DELETED
bayinx/mhx/vi/flows/fullaffine.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
from functools import partial
|
|
2
|
-
from typing import Tuple
|
|
3
|
-
|
|
4
|
-
import equinox as eqx
|
|
5
|
-
import jax
|
|
6
|
-
import jax.numpy as jnp
|
|
7
|
-
from jaxtyping import Array, Scalar
|
|
8
|
-
|
|
9
|
-
from bayinx.core import Flow
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class FullAffine(Flow):
    """
    A full affine flow, mapping each draw to `draws @ scale + shift` where `scale` is
    constrained to be lower-triangular with a strictly positive diagonal.

    # Attributes
    - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
    - `constraints`: A dictionary of constraining transformations.
    """

    def __init__(self, dim: int):
        """
        Initializes a full affine flow.

        # Parameters
        - `dim`: The dimension of the parameter space.
        """
        self.params = {
            "shift": jnp.zeros(dim),
            "scale": jnp.zeros((dim, dim)),
        }

        @eqx.filter_jit
        def constrain_scale(scale: Array):
            # Keep only the lower triangle and exponentiate its diagonal so the matrix
            # is triangular with a positive diagonal: invertible, with a log-determinant
            # equal to the sum of the log of that diagonal.
            diag: Array = jnp.exp(jnp.diag(scale))

            # Return matrix with modified diagonal
            return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)

        # The constraint is applied for every dimension, including dim == 1. Left
        # unconstrained, the zero-initialized 1x1 scale maps every draw onto `shift`,
        # and the log-Jacobian log(0) is -inf (NaN once the raw scale turns negative).
        self.constraints = {"scale": constrain_scale}

    @eqx.filter_jit
    def forward(self, draws: Array) -> Array:
        """
        Applies the affine transformation to a batch of draws.

        # Parameters
        - `draws`: Draws from some layer of a normalizing flow, one draw per row.

        # Returns
        The transformed draws, `draws @ scale + shift`.
        """
        params = self.transform_params()

        # Extract parameters
        shift: Array = params["shift"]
        scale: Array = params["scale"]

        # Compute forward transformation
        draws = draws @ scale + shift

        return draws

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
        """
        Applies the affine transformation and computes its log-absolute-Jacobian, per draw.

        # Parameters
        - `draws`: Draws from some layer of a normalizing flow (mapped over the leading axis).

        # Returns
        A tuple of the transformed draws and the log-absolute-Jacobian of each draw.
        """
        params = self.transform_params()

        # Extract parameters
        shift: Array = params["shift"]
        scale: Array = params["scale"]

        # Compute forward transformation
        draws = draws @ scale + shift

        # For a triangular matrix, log|det| is the sum of the logs of its (positive) diagonal
        laj: Scalar = jnp.log(jnp.diag(scale)).sum()

        return draws, laj
|
bayinx/mhx/vi/flows/planar.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
|
1
|
-
from functools import partial
|
|
2
|
-
from typing import Callable, Dict, Tuple
|
|
3
|
-
|
|
4
|
-
import equinox as eqx
|
|
5
|
-
import jax
|
|
6
|
-
import jax.numpy as jnp
|
|
7
|
-
import jax.random as jr
|
|
8
|
-
from jaxtyping import Array, Float, Scalar
|
|
9
|
-
|
|
10
|
-
from bayinx.core import Flow
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class Planar(Flow):
    """
    A planar flow, mapping each draw `z` to `z + u * tanh(w . z + b)`.

    # Attributes
    - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
    - `constraints`: A dictionary of constraining transformations.

    # Notes
    NOTE(review): the invertibility condition `u . w >= -1` is not enforced by any
    constraint, so training can reach non-invertible configurations.
    """

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Array], Array]]

    def __init__(self, dim: int, key=jr.PRNGKey(0)):
        """
        Initializes a planar flow.

        # Parameters
        - `dim`: The dimension of the parameter space.
        - `key`: Currently unused; `u` and `b` start at zero and `w` at one, so the flow
          starts as the identity map.
        """
        self.params = {
            "u": jnp.zeros(dim),
            "w": jnp.ones(dim),
            "b": jnp.zeros(1),
        }
        self.constraints = {}

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def forward(self, draws: Array) -> Array:
        """
        Applies the planar transformation to each draw.

        # Parameters
        - `draws`: Draws from some layer of a normalizing flow (mapped over the leading axis).

        # Returns
        The transformed draws.
        """
        params = self.transform_params()

        # Extract parameters; `b` is stored with shape (1,) and reduced to a scalar
        w: Array = params["w"]
        u: Array = params["u"]
        b: Array = params["b"].squeeze()

        # Compute forward transformation
        draws = draws + u * jnp.tanh(draws.dot(w) + b)

        return draws

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
        """
        Applies the planar transformation and computes its log-absolute-Jacobian, per draw.

        # Parameters
        - `draws`: Draws from some layer of a normalizing flow (mapped over the leading axis).

        # Returns
        A tuple of the transformed draws and the scalar log-absolute-Jacobian of each draw.
        """
        params = self.transform_params()

        # Extract parameters. `b` is stored with shape (1,); squeezing it keeps the
        # pre-activation and the log-Jacobian scalar per draw, as the return type states.
        # Otherwise they would have shape (1,).
        w: Array = params["w"]
        u: Array = params["u"]
        b: Array = params["b"].squeeze()

        # Compute shared intermediates
        x: Scalar = draws.dot(w) + b

        # Compute forward transformation
        draws = draws + u * jnp.tanh(x)

        # log|det J| = log|1 + h'(x) * (u . w)| with h = tanh, h' = 1 - tanh^2
        h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
        laj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))

        return draws, laj
|
bayinx/mhx/vi/flows/radial.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
|
1
|
-
from functools import partial
|
|
2
|
-
from typing import Callable, Dict, Tuple
|
|
3
|
-
|
|
4
|
-
import equinox as eqx
|
|
5
|
-
import jax
|
|
6
|
-
import jax.numpy as jnp
|
|
7
|
-
import jax.random as jr
|
|
8
|
-
from jax.numpy.linalg import norm
|
|
9
|
-
from jaxtyping import Array, Float, Scalar
|
|
10
|
-
|
|
11
|
-
from bayinx.core import Flow
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class Radial(Flow):
    """
    A radial flow, mapping each draw `z` to `z + beta / (alpha + r) * (z - center)` with
    `r = ||z - center||`.

    # Attributes
    - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
    - `constraints`: A dictionary of constraining transformations.
    """

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]

    def __init__(self, dim: int, key=jr.PRNGKey(0)):
        """
        Initializes a radial flow.

        # Parameters
        - `dim`: The dimension of the parameter space.
        - `key`: Currently unused; parameters are initialized deterministically.
        """
        self.params = {
            # Stored on the log scale and mapped through `exp` by `constraints`, so the
            # effective initial alpha is exp(0.0) = 1.0.
            "alpha": jnp.array(0.0),
            "beta": jnp.array(1.0),
            "center": jnp.ones(dim),
        }
        # alpha > 0 keeps (alpha + r) away from zero, and beta > 0 satisfies the
        # invertibility condition beta >= -alpha. Previously alpha was unconstrained.
        self.constraints = {"alpha": jnp.exp, "beta": jnp.exp}

    @partial(jax.vmap, in_axes=(None, 0))
    @eqx.filter_jit
    def forward(self, draws: Array) -> Array:
        """
        Applies the forward radial transformation for each draw.

        # Parameters
        - `draws`: Draws from some layer of a normalizing flow.

        # Returns
        The transformed samples.
        """
        params = self.transform_params()

        # Extract parameters
        alpha = params["alpha"]
        beta = params["beta"]
        center = params["center"]

        # Compute distance to center per-draw
        r: Array = norm(draws - center)

        # Apply forward transformation
        draws = draws + (beta / (alpha + r)) * (draws - center)

        return draws

    @partial(jax.vmap, in_axes=(None, 0))
    @eqx.filter_jit
    def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
        """
        Applies the radial transformation and computes its log-absolute-Jacobian, per draw.

        # Parameters
        - `draws`: Draws from some layer of a normalizing flow (mapped over the leading axis).

        # Returns
        A tuple of the transformed draws and the log-absolute-Jacobian of each draw.
        """
        params = self.transform_params()

        # Extract parameters
        alpha = params["alpha"]
        beta = params["beta"]
        center = params["center"]

        # Compute distance to center per-draw
        r: Array = norm(draws - center)

        # Compute shared intermediates
        x: Array = beta / (alpha + r)

        # Apply forward transformation
        draws = draws + (x) * (draws - center)

        # log|det J| = log|(1 + alpha*beta/(alpha + r)^2) * (1 + beta/(alpha + r))^(d - 1)|
        laj = jnp.log(
            jnp.abs(
                (1.0 + alpha * beta / (alpha + r) ** 2.0)
                * (1.0 + x) ** (center.size - 1.0)
            )
        )

        return draws, laj
|
bayinx/mhx/vi/flows/sylvester.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
from typing import Callable, Dict
|
|
2
|
-
|
|
3
|
-
from jaxtyping import Array, Float
|
|
4
|
-
|
|
5
|
-
from bayinx.core import Flow
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
# TODO: implement the Sylvester flow (no initializer or transformation methods yet).
class Sylvester(Flow):
    """
    A Sylvester flow (placeholder).

    # Attributes
    - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
    - `constraints`: A dictionary of constraining transformations.

    # Notes
    This class only declares its fields: it defines no `__init__`, `forward`, or
    `adjust_density`. NOTE(review): whether it can be instantiated or used depends on
    what the `Flow` base class provides; confirm before use.
    """

    # Flow parameters, keyed by name.
    params: Dict[str, Float[Array, "..."]]
    # Transformations applied to the parameters, keyed by parameter name.
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|