bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/dists/normal.py
CHANGED
|
@@ -1,16 +1,25 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
1
4
|
import jax.lax as lax
|
|
2
5
|
import jax.numpy as jnp
|
|
3
|
-
import jax.
|
|
4
|
-
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
import jax.scipy.special as jsp
|
|
8
|
+
import jax.tree as jt
|
|
9
|
+
from jaxtyping import Array, ArrayLike, PRNGKeyArray, Real, Scalar
|
|
10
|
+
|
|
11
|
+
from bayinx.core.distribution import Distribution
|
|
12
|
+
from bayinx.core.node import Node
|
|
13
|
+
from bayinx.nodes import Observed
|
|
5
14
|
|
|
6
|
-
|
|
15
|
+
# Archimedes' constant; jnp.pi is the same double-precision value as math.pi.
PI = jnp.pi
|
|
7
16
|
|
|
8
17
|
|
|
9
18
|
def prob(
    x: Real[ArrayLike, "..."],
    mu: Real[ArrayLike, "..."],
    sigma: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The probability density function (PDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the PDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # exp(-z^2 / 2) / (sigma * sqrt(2 * pi)) with z the standardized value
    z = (x - mu) / sigma
    kernel = lax.exp(-0.5 * lax.square(z))
    normalizer = sigma * lax.sqrt(2.0 * PI)

    return kernel / normalizer
|
|
29
38
|
|
|
30
39
|
|
|
31
40
|
def logprob(
    x: Real[ArrayLike, "..."],
    mu: Real[ArrayLike, "..."],
    sigma: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the probability density function (log PDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log PDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # Standardized value; true division already yields a float for integer inputs
    z = (x - mu) / sigma

    # jnp.log promotes integer inputs to floating point. lax.log does not, and it
    # would raise a dtype error for an integer `sigma` such as `sigma=1`.
    return -jnp.log(jnp.sqrt(2.0 * PI)) - jnp.log(sigma) - 0.5 * jnp.square(z)
|
|
51
60
|
|
|
52
61
|
|
|
53
|
-
def
|
|
54
|
-
x:
|
|
55
|
-
mu:
|
|
56
|
-
sigma:
|
|
57
|
-
) ->
|
|
62
|
+
def cdf(
    x: Real[ArrayLike, "..."],
    mu: Real[ArrayLike, "..."],
    sigma: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The cumulative density function (CDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # Standard-normal CDF of the standardized value
    z = (x - mu) / sigma
    return jsp.ndtr(z)
|
|
73
82
|
|
|
74
83
|
|
|
75
|
-
def
|
|
76
|
-
x:
|
|
77
|
-
mu:
|
|
78
|
-
sigma:
|
|
79
|
-
) ->
|
|
84
|
+
def logcdf(
    x: Real[ArrayLike, "..."],
    mu: Real[ArrayLike, "..."],
    sigma: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the cumulative density function (log CDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # log_ndtr stays accurate deep in the left tail, where log(ndtr(z)) would underflow
    z = (x - mu) / sigma
    return jsp.log_ndtr(z)
|
|
95
104
|
|
|
96
105
|
|
|
97
|
-
def
|
|
98
|
-
x:
|
|
99
|
-
mu:
|
|
100
|
-
sigma:
|
|
101
|
-
) ->
|
|
106
|
+
def ccdf(
    x: Real[ArrayLike, "..."],
    mu: Real[ArrayLike, "..."],
    sigma: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The complementary cumulative density function (cCDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # Use symmetry (1 - Phi(z) = Phi(-z)) rather than subtracting from 1,
    # which keeps precision in the right tail.
    reflected = (mu - x) / sigma
    return jsp.ndtr(reflected)
|
|
106
126
|
|
|
107
127
|
|
|
108
|
-
def
|
|
109
|
-
x:
|
|
110
|
-
mu:
|
|
111
|
-
sigma:
|
|
112
|
-
) ->
|
|
128
|
+
def logccdf(
    x: Real[ArrayLike, "..."],
    mu: Real[ArrayLike, "..."],
    sigma: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the complementary cumulative density function (log cCDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # log(1 - Phi(z)) == log(Phi(-z)); log_ndtr handles the far tail stably
    reflected = (mu - x) / sigma
    return jsp.log_ndtr(reflected)
|
|
117
148
|
|
|
118
149
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
sigma: Float[ArrayLike, "..."],
|
|
123
|
-
) -> Float[Array, "..."]:
|
|
124
|
-
# Cast to Array
|
|
125
|
-
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
|
150
|
+
class Normal(Distribution):
    """
    A normal distribution.

    # Attributes
    - `mu`: The mean/location parameter.
    - `sigma`: The standard-deviation/scale parameter.
    """

    mu: Node[Real[Array, "..."]]
    sigma: Node[Real[Array, "..."]]

    def __init__(
        self,
        mu: Real[ArrayLike, "..."] | Node[Real[Array, "..."]],
        sigma: Real[ArrayLike, "..."] | Node[Real[Array, "..."]]
    ):
        """
        Initializes a normal distribution.

        # Parameters
        - `mu`: The mean/location, either array-like or a `Node` wrapping an array.
        - `sigma`: The standard deviation/scale, either array-like or a `Node` wrapping an array.

        # Raises
        - `TypeError`: If a `Node` is passed whose `obj` is not array-like.
        """
        # Initialize mean/location parameter (mu)
        if isinstance(mu, Node):
            # Reject non-array nodes explicitly instead of leaving the field unassigned
            if not isinstance(mu.obj, ArrayLike):
                raise TypeError(
                    "`mu` must be array-like or a Node wrapping an array, "
                    f"got a Node wrapping {type(mu.obj).__name__}."
                )
            self.mu = mu  # type: ignore
        else:
            # Raw values are wrapped in an `Observed` node
            self.mu = Observed(jnp.asarray(mu))

        # Initialize dispersion/scale parameter (sigma)
        if isinstance(sigma, Node):
            if not isinstance(sigma.obj, ArrayLike):
                raise TypeError(
                    "`sigma` must be array-like or a Node wrapping an array, "
                    f"got a Node wrapping {type(sigma.obj).__name__}."
                )
            self.sigma = sigma  # type: ignore
        else:
            self.sigma = Observed(jnp.asarray(sigma))

    def logprob(self, node: Node) -> Scalar:
        """
        Total log density of `node` under this distribution.

        # Parameters
        - `node`: The node whose leaves are evaluated.

        # Returns
        The sum of the elementwise normal log density over every leaf of
        `node.obj` kept by `node._filter_spec`.
        """
        obj, mu, sigma = (
            node.obj,
            self.mu.obj,
            self.sigma.obj
        )

        # Keep only the leaves selected by the node's filter spec
        obj, _ = eqx.partition(obj, node._filter_spec)

        # Log density of a single leaf, summed over its elements.
        # `logprob` here resolves to the module-level function, not this method.
        def leaf_logprob(x: Real[ArrayLike, "..."]) -> Scalar:
            return logprob(x, mu, sigma).sum()

        # Compute log probabilities across leaves
        eval_obj = jt.map(leaf_logprob, obj)

        # Compute total sum
        total = jt.reduce_associative(lambda x, y: x + y, eval_obj, identity=0.0)

        return jnp.asarray(total)

    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
        """
        Draws samples from the distribution.

        # Parameters
        - `shape`: The shape of the standard-normal draws; an `int` is treated as a 1-tuple.
        - `key`: The PRNG key. The default key is created once, when the method
          is defined, so calls without an explicit key return identical draws.

        # Returns
        `jr.normal(key, shape) * sigma + mu`. `shape` must broadcast against the
        shapes of `mu` and `sigma`.
        """
        # Coerce to tuple
        if isinstance(shape, int):
            shape = (shape, )

        return jr.normal(key, shape) * self.sigma.obj + self.mu.obj
|
bayinx/dists/poisson.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax.lax as lax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
import jax.scipy.special as jsp
|
|
8
|
+
import jax.tree as jt
|
|
9
|
+
from jaxtyping import Array, ArrayLike, Integer, PRNGKeyArray, Real, Scalar
|
|
10
|
+
|
|
11
|
+
from bayinx.core.distribution import Distribution
|
|
12
|
+
from bayinx.core.node import Node
|
|
13
|
+
from bayinx.nodes import Observed
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def prob(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The probability mass function (PMF) for a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the PMF.
    - `lam`: The rate parameter (lambda), representing the average number of events.

    # Returns
    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # Exponentiate the log-PMF; it is the numerically careful place to compute the mass
    log_mass = logprob(x, lam)
    return lax.exp(log_mass)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def logprob(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the probability mass function (log PMF) for a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the log PMF.
    - `lam`: The rate parameter (lambda), representing the average number of events.

    # Returns
    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below zero (outside the support) give `-inf`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # Substitute a valid count for out-of-support x so the expression below stays
    # finite (and its gradients NaN-free); those entries are masked at the end.
    outside = x < 0
    x_safe = jnp.where(outside, 0, x)

    # xlogy(x, lam) = x * log(lam) with the convention 0 * log(0) = 0, so
    # log P(X = 0 | lam = 0) = 0 instead of NaN. It also promotes integer `lam`.
    result = jsp.xlogy(x_safe, lam) - lam - jsp.gammaln(x_safe + 1)

    return jnp.where(outside, -jnp.inf, result)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def cdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The cumulative density function (CDF) for a Poisson distribution (P(X <= x)).

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below zero give 0.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # Keep the incomplete-gamma argument positive for out-of-support x (avoids NaN);
    # those entries are overwritten below.
    outside = x < 0
    x_safe = jnp.where(outside, 0, x)

    # P(X <= k) = Q(k + 1, lam): the regularized *upper* incomplete gamma function
    result = jsp.gammaincc(x_safe + 1.0, lam)

    # jnp.where broadcasts; lax.select would require identical shapes and dtypes
    return jnp.where(outside, 0.0, result)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def logcdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the cumulative density function (log CDF) for a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF (ln P(X <= x)).
    - `lam`: The rate parameter (lambda).

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below zero give `-inf`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # Keep the incomplete-gamma argument valid for out-of-support x
    outside = x < 0
    x_safe = jnp.where(outside, 0, x)

    # ln P(X <= k) = ln Q(k + 1, lam) (regularized upper incomplete gamma)
    result = jnp.log(jsp.gammaincc(x_safe + 1.0, lam))

    # Handle values outside of support (broadcasting select)
    return jnp.where(outside, -jnp.inf, result)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def ccdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The complementary cumulative density function (cCDF) for a Poisson distribution (P(X > x)).

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below zero give 1.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # Keep the incomplete-gamma argument valid for out-of-support x
    outside = x < 0
    x_safe = jnp.where(outside, 0, x)

    # P(X > k) = P(k + 1, lam): the regularized *lower* incomplete gamma function
    result = jsp.gammainc(x_safe + 1.0, lam)

    # Handle values outside of support (broadcasting select)
    return jnp.where(outside, 1.0, result)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def logccdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the complementary cumulative density function (log cCDF) for a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF (ln P(X > x)).
    - `lam`: The rate parameter (lambda).

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below zero give 0.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # Keep the incomplete-gamma argument valid for out-of-support x
    outside = x < 0
    x_safe = jnp.where(outside, 0, x)

    # ln P(X > k) = ln P(k + 1, lam) (regularized lower incomplete gamma)
    result = jnp.log(jsp.gammainc(x_safe + 1.0, lam))

    # Handle values outside of support (broadcasting select)
    return jnp.where(outside, 0.0, result)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class Poisson(Distribution):
    """
    A Poisson distribution.

    # Attributes
    - `lam`: The rate parameter (lambda), representing the average number of events in an interval.
    """

    lam: Node[Real[Array, "..."]]

    def __init__(
        self,
        lam: Real[ArrayLike, "..."] | Node[Real[Array, "..."]],
    ):
        """
        Initializes a Poisson distribution.

        # Parameters
        - `lam`: The rate, either array-like or a `Node` wrapping an array.

        # Raises
        - `TypeError`: If a `Node` is passed whose `obj` is not array-like.
        """
        # Initialize rate parameter (lambda)
        if isinstance(lam, Node):
            # Reject non-array nodes explicitly instead of leaving the field unassigned
            if not isinstance(lam.obj, ArrayLike):
                raise TypeError(
                    "`lam` must be array-like or a Node wrapping an array, "
                    f"got a Node wrapping {type(lam.obj).__name__}."
                )
            self.lam = lam  # type: ignore
        else:
            # Raw values are wrapped in an `Observed` node
            self.lam = Observed(jnp.asarray(lam))

    def logprob(self, node: Node) -> Scalar:
        """
        Total log mass of `node` under this distribution.

        # Parameters
        - `node`: The node whose leaves are evaluated.

        # Returns
        The sum of the elementwise Poisson log PMF over every leaf of
        `node.obj` kept by `node._filter_spec`.
        """
        obj, lam = (
            node.obj,
            self.lam.obj,
        )

        # Keep only the leaves selected by the node's filter spec
        obj, _ = eqx.partition(obj, node._filter_spec)

        # Log mass of a single leaf, summed over its elements.
        # `logprob` here resolves to the module-level function, not this method.
        def leaf_logprob(x: Integer[ArrayLike, "..."]) -> Scalar:
            return logprob(x, lam).sum()

        # Compute log probabilities across leaves
        eval_obj = jt.map(leaf_logprob, obj)

        # Compute total sum
        total = jt.reduce_associative(lambda x, y: x + y, eval_obj, identity=0.0)

        return jnp.asarray(total)

    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
        """
        Draws samples from the distribution.

        # Parameters
        - `shape`: The output shape passed to `jr.poisson`; an `int` is treated as a 1-tuple.
        - `key`: The PRNG key. The default key is created once, when the method
          is defined, so calls without an explicit key return identical draws.

        # Returns
        Poisson draws with rate `lam`. `shape` must be compatible with the shape of `lam`.
        """
        # Coerce to tuple
        if isinstance(shape, int):
            shape = (shape, )

        return jr.poisson(key, self.lam.obj, shape=shape)
|
bayinx/flows/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from bayinx.flows.diagaffine import DiagAffine as DiagAffine
|
|
2
|
+
from bayinx.flows.fullaffine import FullAffine as FullAffine
|
|
3
|
+
from bayinx.flows.lowrankaffine import LowRankAffine as LowRankAffine
|
|
4
|
+
from bayinx.flows.planar import Planar as Planar
|
|
5
|
+
from bayinx.flows.sylvester import Sylvester as Sylvester
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from typing import Callable, Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
from jaxtyping import Array, Float, PRNGKeyArray, Scalar
|
|
8
|
+
|
|
9
|
+
from bayinx.core.flow import FlowLayer, FlowSpec
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DiagAffineLayer(FlowLayer):
    """
    A diagonal (element-wise) affine flow: `x -> scale * x + shift`.

    # Attributes
    - `params`: The raw parameters of the flow, `"shift"` and `"scale"`, each of shape `(dim,)`.
    - `constraints`: The constraining transformations for the parameters (`"scale"` is mapped
      through `exp`; presumably applied by `transform_params` on `FlowLayer` — confirm there).
    - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
    """

    params: Dict[str, Array]
    constraints: Dict[str, Callable[[Array], Array]]
    static: bool

    def __init__(self, dim: int, key: PRNGKeyArray):
        """
        Initializes a diagonal affine flow.

        # Parameters
        - `dim`: The dimension of the parameter space.
        - `key`: The PRNG key used to initialize the parameters.
        """
        self.static = False
        # Split key
        k1, k2 = jr.split(key)

        # Initialize parameters (small random values scaled by 1/sqrt(dim))
        self.params = {
            "shift": jr.normal(k1, (dim,)) / dim**0.5,
            "scale": jr.normal(k2, (dim,)) / dim**0.5,
        }

        # Define constraints: keep the scale positive
        self.constraints = {"scale": jnp.exp}

    def __forward(self, draw: Float[Array, " dim"]) -> Float[Array, " dim"]:
        # Forward map for a single draw
        params = self.transform_params()

        assert len(draw.shape) == 1

        # Extract parameters
        shift: Float[Array, " dim"] = params["shift"]
        scale: Float[Array, " dim"] = params["scale"]

        # Compute forward transformation
        draw = draw * scale + shift

        return draw

    @eqx.filter_jit
    def forward(self, draws: Float[Array, "draws dim"]) -> Float[Array, "draws dim"]:
        """
        Applies the affine map to each row of `draws`.
        """
        f = jax.vmap(self.__forward, 0)
        return f(draws)

    def __adjust(self, draw: Float[Array, " dim"]) -> Scalar:
        # Log-Jacobian for a single draw; it does not depend on `draw` for this flow
        params = self.transform_params()

        # Extract parameters
        scale: Float[Array, " dim"] = params["scale"]

        # Compute log-Jacobian adjustment: log|det diag(scale)| = sum(log(scale))
        lja: Scalar = jnp.log(scale).sum()

        assert lja.shape == ()

        return lja

    @eqx.filter_jit
    def adjust(self, draws: Float[Array, "draws dim"]) -> Float[Array, " draws"]:
        """
        Returns the log-Jacobian adjustment for each row of `draws` (one scalar per draw).
        """
        f = jax.vmap(self.__adjust, 0)
        return f(draws)

    def __forward_and_adjust(self, draw: Float[Array, " dim"]) -> Tuple[Float[Array, " dim"], Scalar]:
        # Forward map and log-Jacobian for a single draw, sharing one parameter transform
        params = self.transform_params()

        assert len(draw.shape) == 1

        # Extract parameters
        shift: Float[Array, " dim"] = params["shift"]
        scale: Float[Array, " dim"] = params["scale"]

        # Compute forward transformation
        draw = scale * draw + shift

        assert len(draw.shape) == 1

        # Compute log-Jacobian adjustment
        lja: Scalar = jnp.log(scale).sum()

        assert lja.shape == ()

        return draw, lja

    @eqx.filter_jit
    def forward_and_adjust(
        self, draws: Float[Array, "draws dim"]
    ) -> Tuple[Float[Array, "draws dim"], Float[Array, " draws"]]:
        """
        Returns the transformed draws and their per-draw log-Jacobian adjustments.
        """
        f = jax.vmap(self.__forward_and_adjust, 0)
        return f(draws)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class DiagAffine(FlowSpec):
    """
    A specification for the diagonal affine flow.

    # Attributes
    - `key`: The PRNG key handed to every layer built by `construct`.
    """

    key: PRNGKeyArray

    def __init__(self, key: PRNGKeyArray = jr.key(0)):
        self.key = key

    def construct(self, dim: int) -> DiagAffineLayer:
        """
        Builds a `DiagAffineLayer` of dimension `dim` seeded with `self.key`.

        Every call reuses the same key, so layers of equal `dim` start from identical parameters.
        """
        layer = DiagAffineLayer(dim, key=self.key)
        return layer
|