bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/dists/bernoulli.py
CHANGED
|
@@ -1,33 +1,202 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
1
4
|
import jax.lax as lax
|
|
2
|
-
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
import jax.tree as jt
|
|
8
|
+
from jaxtyping import Array, ArrayLike, Float, Integer, PRNGKeyArray, Real, Scalar
|
|
9
|
+
|
|
10
|
+
from bayinx.core.distribution import Distribution
|
|
11
|
+
from bayinx.core.node import Node
|
|
12
|
+
from bayinx.nodes import Observed
|
|
3
13
|
|
|
4
14
|
|
|
5
|
-
|
|
6
|
-
|
|
15
|
+
def prob(
    x: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."]
) -> Real[Array, "..."]:
    """
    The probability mass function (PMF) for a Bernoulli distribution.

    # Parameters
    - `x`: Where to evaluate the PMF (must be 0 or 1).
    - `p`: The probability of success (p).

    # Returns
    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
    """
    # Local import: `jax.scipy.special` is not among this module's top-level imports.
    import jax.scipy.special as jsp

    # Cast to Array
    x, p = jnp.asarray(x), jnp.asarray(p)

    # Bernoulli PMF: p^x * (1-p)^(1-x), evaluated in log space.
    # xlogy / xlog1py define 0 * log(0) as 0, so the degenerate cases
    # (x = 0, p = 0) and (x = 1, p = 1) give probability 1 instead of NaN.
    return lax.exp(jsp.xlogy(x, p) + jsp.xlog1py(1.0 - x, -p))
|
|
19
34
|
|
|
20
35
|
|
|
21
|
-
def logprob(
|
|
36
|
+
def logprob(
    x: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."]
) -> Real[Array, "..."]:
    """
    The log of the probability mass function (log PMF) for a Bernoulli distribution.

    # Parameters
    - `x`: Where to evaluate the log PMF (must be 0 or 1).
    - `p`: The probability of success (p).

    # Returns
    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
    """
    # Local import: `jax.scipy.special` is not among this module's top-level imports.
    import jax.scipy.special as jsp

    # Cast to Array
    x, p = jnp.asarray(x), jnp.asarray(p)

    # x * log(p) + (1 - x) * log1p(-p). xlogy / xlog1py treat 0 * log(0) as 0,
    # so (x = 0, p = 0) and (x = 1, p = 1) give 0 instead of NaN. log1p(-p)
    # is also more accurate than log(1 - p) for small p.
    return jsp.xlogy(x, p) + jsp.xlog1py(1.0 - x, -p)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def cdf(
    x: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."]
) -> Real[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a Bernoulli distribution.

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `p`: The probability of success (p).

    # Returns
    The CDF evaluated at `x`: 0 for `x < 0`, `1 - p` for `0 <= x < 1`, and 1 for `x >= 1`.
    The output will have the broadcasted shapes of `x` and `p`.
    """
    x_arr = jnp.asarray(x)
    p_arr = jnp.asarray(p)

    # On or above zero, either only the failure mass or all of it has accumulated.
    at_or_above_zero = jnp.where(x_arr < 1.0, 1.0 - p_arr, 1.0)

    # Below the support nothing has accumulated yet.
    return jnp.where(x_arr < 0.0, 0.0, at_or_above_zero)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def logcdf(
    x: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."]
) -> Real[Array, "..."]:
    """
    The log of the cumulative distribution function (log CDF) for a Bernoulli distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF.
    - `p`: The probability of success (p).

    # Returns
    The log CDF evaluated at `x`: `-inf` for `x < 0`, `log(1 - p)` for `0 <= x < 1`,
    and 0 for `x >= 1`. The output will have the broadcasted shapes of `x` and `p`.
    """
    x_arr = jnp.asarray(x)
    p_arr = jnp.asarray(p)

    # log1p(-p) is log(1 - p), the log of the failure mass.
    at_or_above_zero = jnp.where(x_arr < 1.0, lax.log1p(-p_arr), 0.0)

    # log(0) below the support.
    return jnp.where(x_arr < 0.0, -jnp.inf, at_or_above_zero)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def ccdf(
    x: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."]
) -> Real[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF) for a Bernoulli distribution.

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `p`: The probability of success (p).

    # Returns
    The cCDF evaluated at `x`: 1 for `x < 0`, `p` for `0 <= x < 1`, and 0 for `x >= 1`.
    The output will have the broadcasted shapes of `x` and `p`.
    """
    x_arr = jnp.asarray(x)
    p_arr = jnp.asarray(p)

    # On or above zero, only the success mass (or nothing) lies strictly above x.
    at_or_above_zero = jnp.where(x_arr < 1.0, p_arr, 0.0)

    # Below the support, all of the mass lies above x.
    return jnp.where(x_arr < 0.0, 1.0, at_or_above_zero)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def logccdf(
    x: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."]
) -> Real[Array, "..."]:
    """
    The log of the complementary cumulative distribution function (log cCDF) for a Bernoulli distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF.
    - `p`: The probability of success (p).

    # Returns
    The log cCDF evaluated at `x`: 0 for `x < 0`, `log(p)` for `0 <= x < 1`, and `-inf`
    for `x >= 1`. The output will have the broadcasted shapes of `x` and `p`.
    """
    x_arr = jnp.asarray(x)
    p_arr = jnp.asarray(p)

    # On or above zero: log of the success mass, or log(0) once past 1.
    at_or_above_zero = jnp.where(x_arr < 1.0, lax.log(p_arr), -jnp.inf)

    # Below the support the cCDF is 1, i.e. log cCDF is 0.
    return jnp.where(x_arr < 0.0, 0.0, at_or_above_zero)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class Bernoulli(Distribution):
    """
    A Bernoulli distribution.

    # Attributes
    - `p`: The probability of success (p), held as a `Node`.
    """

    p: Node[Real[Array, "..."]]

    def __init__(
        self,
        p: Real[ArrayLike, "..."] | Node[Real[ArrayLike, "..."]],
    ):
        # Initialize probability of success parameter (p)
        if isinstance(p, Node):
            # Only nodes wrapping an array can be used as a parameter; raise
            # explicitly instead of leaving `self.p` unassigned.
            if not isinstance(p.obj, ArrayLike):
                raise TypeError(
                    "Bernoulli parameter `p` must be a Node wrapping an array-like "
                    f"object, got a Node wrapping {type(p.obj).__name__}."
                )
            self.p = p  # type: ignore
        else:
            # Plain array-likes are wrapped in an `Observed` node.
            self.p = Observed(jnp.asarray(p))

    def logprob(self, node: Node) -> Scalar:
        """
        Total Bernoulli log PMF of the array leaves of `node.obj`.

        Leaves excluded by `node._filter_spec` are dropped before evaluation. Each
        remaining leaf is scored with the module-level `logprob`. The per-leaf
        sums are then added together.
        """
        obj, p = (
            node.obj,
            self.p.obj,
        )

        # Filter out irrelevant values
        obj, _ = eqx.partition(obj, node._filter_spec)

        # Helper function for the single-leaf log-probability evaluation
        def leaf_logprob(x: Real[ArrayLike, "..."]) -> Scalar:
            return logprob(x, p).sum()

        # Compute log probabilities across leaves
        eval_obj = jt.map(leaf_logprob, obj)

        # Compute total sum
        total = jt.reduce_associative(lambda x, y: x + y, eval_obj, identity=0.0)

        return jnp.asarray(total)

    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
        """
        Draw Bernoulli samples with success probability `self.p.obj`.

        `shape` is forwarded to `jr.bernoulli`, so `p` must broadcast to it. The
        default `key` is created once, when the class is defined. As a result,
        calls that omit `key` always return the same draws.
        """
        # Coerce to tuple
        if isinstance(shape, int):
            shape = (shape,)

        return jr.bernoulli(key, self.p.obj, shape)
|
bayinx/dists/binomial.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax.lax as lax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
import jax.scipy.special as jsp
|
|
8
|
+
import jax.tree as jt
|
|
9
|
+
from jaxtyping import Array, ArrayLike, Float, Integer, PRNGKeyArray, Real, Scalar
|
|
10
|
+
|
|
11
|
+
from bayinx.core.distribution import Distribution
|
|
12
|
+
from bayinx.core.node import Node
|
|
13
|
+
from bayinx.nodes import Observed
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def log_binom_coeff(n: ArrayLike, x: ArrayLike) -> Array:
    """
    Log of the binomial coefficient C(n, x), via log-gamma: log n! - log x! - log (n - x)!.
    """
    n_arr = jnp.asarray(n)
    x_arr = jnp.asarray(x)

    log_n_fact = jsp.gammaln(n_arr + 1)
    log_x_fact = jsp.gammaln(x_arr + 1)
    log_rest_fact = jsp.gammaln(n_arr - x_arr + 1)

    return log_n_fact - log_x_fact - log_rest_fact
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def prob(
    x: Integer[ArrayLike, "..."],
    n: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The probability mass function (PMF) for a Binomial distribution.

    Computed by exponentiating the module-level `logprob`.

    # Parameters
    - `x`: Where to evaluate the PMF (number of successes).
    - `n`: The number of trials.
    - `p`: The probability of success.

    # Returns
    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
    """
    x_arr = jnp.asarray(x)
    n_arr = jnp.asarray(n)
    p_arr = jnp.asarray(p)

    log_pmf = logprob(x_arr, n_arr, p_arr)
    return lax.exp(log_pmf)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def logprob(
    x: Integer[ArrayLike, "..."],
    n: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the probability mass function (log PMF) for a Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the log PMF (number of successes).
    - `n`: The number of trials.
    - `p`: The probability of success.

    # Returns
    The log PMF evaluated at `x` (`-inf` for `x < 0` or `x > n`). The output will have
    the broadcasted shapes of `x`, `n`, and `p`.
    """
    # Cast to Array
    x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)

    # log C(n, x) + x log(p) + (n - x) log(1 - p). xlogy / xlog1py define
    # 0 * log(0) as 0, so the degenerate cases (x = 0 with p = 0, x = n with p = 1)
    # give a log PMF of 0 instead of NaN.
    log_pmf = log_binom_coeff(n, x) + jsp.xlogy(x, p) + jsp.xlog1py(n - x, -p)

    # Counts below 0 or above n carry no mass.
    in_support = (x >= 0) & (x <= n)
    return jnp.where(in_support, log_pmf, -jnp.inf)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def cdf(
    x: Integer[ArrayLike, "..."],
    n: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the CDF (P(X <= x)); non-integer values are floored.
    - `n`: The number of trials.
    - `p`: The probability of success.

    # Returns
    The CDF evaluated at `x`: 0 for `x < 0` and 1 for `x >= n`. The output will have
    the broadcasted shapes of `x`, `n`, and `p`.
    """
    # Cast to Array
    x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
    k = jnp.floor(x)

    # P(X <= k) = I_{1-p}(n - k, k + 1) only holds for 0 <= k < n. Outside that
    # range betainc receives a non-positive shape parameter and returns NaN. Feed
    # it harmless placeholder arguments there and select the exact value instead.
    interior = (k >= 0) & (k < n)
    safe_k = jnp.where(interior, k, 0)
    safe_n = jnp.where(interior, n, 1)
    body = jsp.betainc(safe_n - safe_k, safe_k + 1, 1.0 - p)

    return jnp.where(k < 0, 0.0, jnp.where(k >= n, 1.0, body))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def logcdf(
    x: Integer[ArrayLike, "..."],
    n: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the cumulative distribution function (log CDF) for a Binomial distribution.

    Computed as the log of the module-level `cdf`.

    # Parameters
    - `x`: Where to evaluate the log CDF (ln P(X <= x)).
    - `n`: The number of trials.
    - `p`: The probability of success.

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
    """
    x_arr, n_arr, p_arr = (jnp.asarray(v) for v in (x, n, p))

    cdf_val = cdf(x_arr, n_arr, p_arr)
    return lax.log(cdf_val)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def ccdf(
    x: Integer[ArrayLike, "..."],
    n: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF) for a Binomial distribution (P(X > x)).

    # Parameters
    - `x`: Where to evaluate the cCDF; non-integer values are floored.
    - `n`: The number of trials.
    - `p`: The probability of success.

    # Returns
    The cCDF evaluated at `x`: 1 for `x < 0` and 0 for `x >= n`. The output will have
    the broadcasted shapes of `x`, `n`, and `p`.
    """
    # Cast to Array
    x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
    k = jnp.floor(x)

    # P(X > k) = I_p(k + 1, n - k) only holds for 0 <= k < n. Outside that range
    # betainc receives a non-positive shape parameter and returns NaN. Feed it
    # harmless placeholder arguments there and select the exact value instead.
    interior = (k >= 0) & (k < n)
    safe_k = jnp.where(interior, k, 0)
    safe_n = jnp.where(interior, n, 1)
    body = jsp.betainc(safe_k + 1, safe_n - safe_k, p)

    return jnp.where(k < 0, 1.0, jnp.where(k >= n, 0.0, body))
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def logccdf(
    x: Integer[ArrayLike, "..."],
    n: Integer[ArrayLike, "..."],
    p: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the complementary cumulative distribution function (log cCDF) for a Binomial distribution.

    Computed as the log of the module-level `ccdf`.

    # Parameters
    - `x`: Where to evaluate the log cCDF.
    - `n`: The number of trials.
    - `p`: The probability of success.

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
    """
    x_arr, n_arr, p_arr = (jnp.asarray(v) for v in (x, n, p))

    ccdf_val = ccdf(x_arr, n_arr, p_arr)
    return lax.log(ccdf_val)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class Binomial(Distribution):
    """
    A Binomial distribution.

    # Attributes
    - `n`: The number of trials parameter (integer), held as a `Node`.
    - `p`: The probability of success parameter (float, [0, 1]), held as a `Node`.
    """

    n: Node[Integer[Array, "..."]]
    p: Node[Real[Array, "..."]]

    def __init__(
        self,
        n: Integer[ArrayLike, "..."] | Node[Integer[Array, "..."]],
        p: Real[ArrayLike, "..."] | Node[Real[Array, "..."]]
    ):
        # Initialize number of trials (n)
        if isinstance(n, Node):
            # Only nodes wrapping an array can be used as a parameter; raise
            # explicitly instead of leaving `self.n` unassigned.
            if not isinstance(n.obj, ArrayLike):
                raise TypeError(
                    "Binomial parameter `n` must be a Node wrapping an array-like "
                    f"object, got a Node wrapping {type(n.obj).__name__}."
                )
            self.n = n  # type: ignore
        else:
            # Plain array-likes are wrapped in an `Observed` node.
            self.n = Observed(jnp.asarray(n))

        # Initialize probability of success parameter (p)
        if isinstance(p, Node):
            if not isinstance(p.obj, ArrayLike):
                raise TypeError(
                    "Binomial parameter `p` must be a Node wrapping an array-like "
                    f"object, got a Node wrapping {type(p.obj).__name__}."
                )
            self.p = p  # type: ignore
        else:
            self.p = Observed(jnp.asarray(p))

    def logprob(self, node: Node) -> Scalar:
        """
        Total Binomial log PMF of the array leaves of `node.obj`.

        Leaves excluded by `node._filter_spec` are dropped before evaluation. Each
        remaining leaf is scored with the module-level `logprob`. The per-leaf
        sums are then added together.
        """
        obj, n, p = (
            node.obj,
            self.n.obj,
            self.p.obj
        )

        # Filter out irrelevant values
        obj, _ = eqx.partition(obj, node._filter_spec)

        # Helper function for the single-leaf log-probability evaluation
        def leaf_logprob(k: Float[ArrayLike, "..."]) -> Scalar:
            return logprob(k, n, p).sum()

        # Compute log probabilities across leaves
        eval_obj = jt.map(leaf_logprob, obj)

        # Compute total sum
        total = jt.reduce_associative(lambda x, y: x + y, eval_obj, identity=0.0)

        return jnp.asarray(total)

    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
        """
        Draw Binomial samples using `jr.binomial` with `n = self.n.obj` and `p = self.p.obj`.

        The default `key` is created once, when the class is defined. As a result,
        calls that omit `key` always return the same draws.
        """
        # Coerce to tuple
        if isinstance(shape, int):
            shape = (shape,)

        # jr.binomial takes n, p and the output shape.
        return jr.binomial(key, self.n.obj, self.p.obj, shape=shape)
|
|
bayinx/dists/exponential.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax.lax as lax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jr
|
|
7
|
+
import jax.tree as jt
|
|
8
|
+
from jaxtyping import Array, ArrayLike, PRNGKeyArray, Real, Scalar
|
|
9
|
+
|
|
10
|
+
from bayinx.core.distribution import Distribution
|
|
11
|
+
from bayinx.core.node import Node
|
|
12
|
+
from bayinx.nodes import Observed
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def prob(
    x: Real[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The probability density function (PDF) for an Exponential distribution.

    # Parameters
    - `x`: Where to evaluate the PDF (0 is returned for `x < 0`).
    - `lam`: The rate parameter (lambda), must be > 0.

    # Returns
    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0

    # Evaluate the density at a clamped point so the discarded branch cannot
    # overflow (exp of a large positive argument) and poison gradients.
    safe_x = jnp.where(in_support, x, 0.0)
    density = lam * jnp.exp(-lam * safe_x)

    # jnp.where broadcasts the scalar fill value. lax.select would require it to
    # match the shape and dtype of `density` exactly.
    return jnp.where(in_support, density, 0.0)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def logprob(
    x: Real[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the probability density function (log PDF) for an Exponential distribution.

    # Parameters
    - `x`: Where to evaluate the log PDF (`-inf` is returned for `x < 0`).
    - `lam`: The rate parameter (lambda), must be > 0.

    # Returns
    The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0
    safe_x = jnp.where(in_support, x, 0.0)

    # jnp.log promotes integer rates to float, unlike lax.log.
    log_density = jnp.log(lam) - lam * safe_x

    # jnp.where broadcasts the fill value. lax.select requires equal shapes and dtypes.
    return jnp.where(in_support, log_density, -jnp.inf)
|
|
63
|
+
|
|
64
|
+
def cdf(
    x: Real[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The cumulative distribution function (CDF) for an Exponential distribution (P(X <= x)).

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The CDF evaluated at `x` (0 for `x < 0`). The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0
    safe_x = jnp.where(in_support, x, 0.0)

    # -expm1(-t) equals 1 - exp(-t) but does not cancel to 0 when t is tiny.
    cdf_val = -jnp.expm1(-lam * safe_x)

    # jnp.where broadcasts the fill value. lax.select requires equal shapes and dtypes.
    return jnp.where(in_support, cdf_val, 0.0)
|
|
87
|
+
|
|
88
|
+
def logcdf(
    x: Real[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the cumulative distribution function (log CDF) for an Exponential distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF (ln P(X <= x)).
    - `lam`: The rate parameter (lambda).

    # Returns
    The log CDF evaluated at `x` (`-inf` for `x <= 0`). The output will have the broadcasted
    shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0
    # Out-of-support points use x = 1 so the discarded branch stays finite for lam > 0.
    safe_x = jnp.where(in_support, x, 1.0)
    t = lam * safe_x

    # log(1 - exp(-t)) computed stably ("log1mexp"):
    # - log(-expm1(-t)) is accurate for t <= log 2;
    # - log1p(-exp(-t)) is accurate for larger t.
    # Each branch is evaluated on a clamped argument so the unused branch stays finite.
    log2 = jnp.log(2.0)
    small_t = jnp.log(-jnp.expm1(-jnp.minimum(t, log2)))
    large_t = jnp.log1p(-jnp.exp(-jnp.maximum(t, log2)))
    log_cdf = jnp.where(t <= log2, small_t, large_t)

    # jnp.where broadcasts the fill value. lax.select requires equal shapes and dtypes.
    return jnp.where(in_support, log_cdf, -jnp.inf)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def ccdf(
    x: Real[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF) for an Exponential distribution (P(X > x)).

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The cCDF evaluated at `x` (1 for `x < 0`). The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0

    # Clamp so exp(-lam * x) cannot overflow in the discarded branch.
    safe_x = jnp.where(in_support, x, 0.0)
    survival = jnp.exp(-lam * safe_x)

    # jnp.where broadcasts the fill value. lax.select requires equal shapes and dtypes.
    return jnp.where(in_support, survival, 1.0)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def logccdf(
    x: Real[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the complementary cumulative distribution function (log cCDF) for an Exponential distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The log cCDF evaluated at `x` (0 for `x < 0`). The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # log(cCDF(x)) = -lambda * x for x >= 0. For x < 0 the cCDF is 1, so its log is 0.
    # jnp.where broadcasts the fill value. lax.select requires equal shapes and dtypes.
    return jnp.where(x >= 0, -lam * x, 0.0)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class Exponential(Distribution):
    """
    An Exponential distribution.

    # Attributes
    - `lam`: The rate parameter (lambda), must be positive; held as a `Node`.
    """

    lam: Node[Real[Array, "..."]]

    def __init__(
        self,
        lam: Real[ArrayLike, "..."] | Node[Real[Array, "..."]],
    ):
        # Initialize rate parameter (lambda)
        if isinstance(lam, Node):
            # Only nodes wrapping an array can be used as a parameter; raise
            # explicitly instead of leaving `self.lam` unassigned.
            if not isinstance(lam.obj, ArrayLike):
                raise TypeError(
                    "Exponential parameter `lam` must be a Node wrapping an array-like "
                    f"object, got a Node wrapping {type(lam.obj).__name__}."
                )
            self.lam = lam  # type: ignore
        else:
            # Plain array-likes are wrapped in an `Observed` node.
            self.lam = Observed(jnp.asarray(lam))

    def logprob(self, node: Node) -> Scalar:
        """
        Total Exponential log PDF of the array leaves of `node.obj`.

        Leaves excluded by `node._filter_spec` are dropped before evaluation. Each
        remaining leaf is scored with the module-level `logprob`. The per-leaf
        sums are then added together.
        """
        obj, lam = (
            node.obj,
            self.lam.obj,
        )

        # Filter out irrelevant values
        obj, _ = eqx.partition(obj, node._filter_spec)

        # Helper function for the single-leaf log-probability evaluation
        def leaf_logprob(x: Real[ArrayLike, "..."]) -> Scalar:
            return logprob(x, lam).sum()

        # Compute log probabilities across leaves
        eval_obj = jt.map(leaf_logprob, obj)

        # Compute total sum
        total = jt.reduce_associative(lambda x, y: x + y, eval_obj, identity=0.0)

        return jnp.asarray(total)

    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
        """
        Draw Exponential samples by scaling unit-rate draws by `1 / self.lam.obj`.

        The default `key` is created once, when the class is defined. As a result,
        calls that omit `key` always return the same draws.
        """
        # Coerce to tuple
        if isinstance(shape, int):
            shape = (shape,)

        # A unit-rate exponential divided by lambda has rate lambda.
        return jr.exponential(key, shape=shape) * 1.0 / self.lam.obj
|