bayinx 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/dists/__init__.py +3 -0
- bayinx/dists/censored/__init__.py +3 -0
- bayinx/dists/censored/gamma2/__init__.py +3 -0
- bayinx/dists/censored/gamma2/r.py +6 -3
- bayinx/dists/censored/posnormal/__init__.py +3 -0
- bayinx/dists/censored/posnormal/r.py +68 -0
- bayinx/dists/normal.py +46 -4
- bayinx/dists/posnormal.py +258 -0
- {bayinx-0.3.3.dist-info → bayinx-0.3.5.dist-info}/METADATA +1 -1
- {bayinx-0.3.3.dist-info → bayinx-0.3.5.dist-info}/RECORD +11 -7
- {bayinx-0.3.3.dist-info → bayinx-0.3.5.dist-info}/WHEEL +0 -0
bayinx/dists/censored/gamma2/r.py
CHANGED
@@ -19,9 +19,10 @@ def prob(
|
|
19
19
|
- `x`: Value(s) at which to evaluate the PMF/PDF.
|
20
20
|
- `mu`: The positive mean.
|
21
21
|
- `nu`: The positive inverse dispersion.
|
22
|
+
- `censor`: The positive censor value.
|
22
23
|
|
23
24
|
# Returns
|
24
|
-
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `
|
25
|
+
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
|
25
26
|
"""
|
26
27
|
evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
|
27
28
|
|
@@ -29,7 +30,7 @@ def prob(
|
|
29
30
|
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
|
30
31
|
censored: Array = jnp.array(x == censor) # pyright: ignore
|
31
32
|
|
32
|
-
# Evaluate
|
33
|
+
# Evaluate probability mass/density function
|
33
34
|
evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
|
34
35
|
evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
|
35
36
|
|
@@ -49,9 +50,10 @@ def logprob(
|
|
49
50
|
- `x`: Value(s) at which to evaluate the log PMF/PDF.
|
50
51
|
- `mu`: The positive mean/location.
|
51
52
|
- `nu`: The positive inverse dispersion.
|
53
|
+
- `censor`: The positive censor value.
|
52
54
|
|
53
55
|
# Returns
|
54
|
-
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `
|
56
|
+
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
|
55
57
|
"""
|
56
58
|
evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
|
57
59
|
|
@@ -59,6 +61,7 @@ def logprob(
|
|
59
61
|
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
|
60
62
|
censored: Array = jnp.array(x == censor) # pyright: ignore
|
61
63
|
|
64
|
+
# Evaluate log probability mass/density function
|
62
65
|
evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
|
63
66
|
evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
|
64
67
|
|
bayinx/dists/censored/posnormal/r.py
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
from jaxtyping import Array, ArrayLike, Float
|
3
|
+
|
4
|
+
from bayinx.dists import posnormal
|
5
|
+
|
6
|
+
|
7
|
+
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The mixed probability mass/density function (PMF/PDF) for a right-censored positive Normal distribution.

    Observations strictly between 0 and `censor` contribute the positive Normal density; an observation
    equal to `censor` carries the point mass P(X > censor), i.e. the positive Normal cCDF at `censor`.

    # Parameters
    - `x`: Value(s) at which to evaluate the PMF/PDF.
    - `mu`: The mean/location of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation of the underlying (untruncated) Normal.
    - `censor`: The positive censor value.

    # Returns
    The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
    Entries with `x <= 0` or `x > censor` are 0.
    """
    x, mu, sigma, censor = map(jnp.array, (x, mu, sigma, censor))

    # Observed (uncensored) values lie in the open interval (0, censor)
    observed: Array = jnp.logical_and(x > 0.0, x < censor)
    density: Array = jnp.where(observed, posnormal.prob(x, mu, sigma), 0.0)

    # A value equal to the censor carries the censored tail mass
    return jnp.where(x == censor, posnormal.ccdf(x, mu, sigma), density)
|
37
|
+
|
38
|
+
|
39
|
+
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored positive Normal distribution.

    Observations strictly between 0 and `censor` contribute the positive Normal log density; an observation
    equal to `censor` contributes the log of the point mass P(X > censor), i.e. the positive Normal log cCDF.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PMF/PDF.
    - `mu`: The mean/location of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation of the underlying (untruncated) Normal.
    - `censor`: The positive censor value.

    # Returns
    The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
    Entries with `x <= 0` or `x > censor` are -inf.
    """
    x, mu, sigma, censor = map(jnp.array, (x, mu, sigma, censor))

    # Observed (uncensored) values lie in the open interval (0, censor)
    observed: Array = jnp.logical_and(x > jnp.array(0.0), x < censor)
    log_density: Array = jnp.where(observed, posnormal.logprob(x, mu, sigma), -jnp.inf)

    # A value equal to the censor carries the censored tail mass
    return jnp.where(x == censor, posnormal.logccdf(x, mu, sigma), log_density)
|
bayinx/dists/normal.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1
1
|
import jax.lax as lax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jax.scipy.special as jss
|
2
4
|
from jaxtyping import Array, ArrayLike, Float
|
3
5
|
|
4
6
|
__PI = 3.141592653589793
|
@@ -18,8 +20,10 @@ def prob(
|
|
18
20
|
# Returns
|
19
21
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
20
22
|
"""
|
23
|
+
# Cast to Array
|
24
|
+
x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
|
21
25
|
|
22
|
-
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (
|
26
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (
|
23
27
|
sigma * lax.sqrt(2.0 * __PI)
|
24
28
|
)
|
25
29
|
|
@@ -38,9 +42,11 @@ def logprob(
|
|
38
42
|
# Returns
|
39
43
|
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
40
44
|
"""
|
45
|
+
# Cast to Array
|
46
|
+
x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
|
41
47
|
|
42
48
|
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
|
43
|
-
(x - mu) / sigma
|
49
|
+
(x - mu) / sigma
|
44
50
|
)
|
45
51
|
|
46
52
|
|
@@ -58,8 +64,10 @@ def uprob(
|
|
58
64
|
# Returns
|
59
65
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
60
66
|
"""
|
67
|
+
# Cast to Array
|
68
|
+
x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
|
61
69
|
|
62
|
-
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
|
70
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
|
63
71
|
|
64
72
|
|
65
73
|
def ulogprob(
|
@@ -76,5 +84,39 @@ def ulogprob(
|
|
76
84
|
# Returns
|
77
85
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
78
86
|
"""
|
87
|
+
# Cast to Array
|
88
|
+
x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
|
79
89
|
|
80
|
-
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
|
90
|
+
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
|
91
|
+
|
92
|
+
def cdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a Normal distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the CDF.
    - `mu`: The mean/location parameter(s).
    - `sigma`: The positive standard deviation parameter(s).

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Standardise, then apply the standard Normal CDF
    z: Array = (jnp.array(x) - jnp.array(mu)) / jnp.array(sigma)

    return jss.ndtr(z)
|
99
|
+
|
100
|
+
def logcdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed cumulative distribution function (log CDF) for a Normal distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the log CDF.
    - `mu`: The mean/location parameter(s).
    - `sigma`: The positive standard deviation parameter(s).

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    Uses `log_ndtr` directly instead of `log(ndtr(...))`, so the far lower tail does not collapse to log(0).
    """
    # Standardise, then apply the standard Normal log CDF
    z: Array = (jnp.array(x) - jnp.array(mu)) / jnp.array(sigma)

    return jss.log_ndtr(z)
|
107
|
+
|
108
|
+
def ccdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF, survival function) for a Normal distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the cCDF.
    - `mu`: The mean/location parameter(s).
    - `sigma`: The positive standard deviation parameter(s).

    # Returns
    The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    Uses the symmetry 1 - Phi(z) = Phi(-z) rather than subtracting the CDF from 1, which would cancel
    catastrophically in the upper tail.
    """
    # Reflected standard score
    z_reflected: Array = (jnp.array(mu) - jnp.array(x)) / jnp.array(sigma)

    return jss.ndtr(z_reflected)
|
115
|
+
|
116
|
+
def logccdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed complementary cumulative distribution function (log cCDF) for a Normal distribution.

    # Parameters
    - `x`: Value(s) at which to evaluate the log cCDF.
    - `mu`: The mean/location parameter(s).
    - `sigma`: The positive standard deviation parameter(s).

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    Evaluated as log Phi((mu - x) / sigma) via `log_ndtr`, using the symmetry 1 - Phi(z) = Phi(-z).
    """
    # Reflected standard score
    z_reflected: Array = (jnp.array(mu) - jnp.array(x)) / jnp.array(sigma)

    return jss.log_ndtr(z_reflected)
|
bayinx/dists/posnormal.py
ADDED
@@ -0,0 +1,258 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
from jaxtyping import Array, ArrayLike, Float
|
3
|
+
|
4
|
+
from bayinx.dists import normal
|
5
|
+
|
6
|
+
|
7
|
+
def prob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The probability density function (PDF) for a positive Normal distribution.

    This is a Normal(`mu`, `sigma`) truncated to the non-negative reals. Its density is the Normal
    density divided by Phi(`mu` / `sigma`), the probability mass the untruncated Normal places on [0, inf).

    # Parameters
    - `x`: Value(s) at which to evaluate the PDF.
    - `mu`: The mean/location of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation of the underlying (untruncated) Normal.

    # Returns
    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are 0.

    # Notes
    The density is computed as exp(log PDF). Dividing directly by Phi(`mu` / `sigma`) breaks once
    `mu` / `sigma` is very negative: the normaliser underflows to 0, giving NaN or inf. Its log,
    computed with `log_ndtr`, stays finite.
    """
    # Cast to Array
    x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)

    # Construct boolean mask for non-negative elements
    non_negative: Array = jnp.array(0.0) <= x

    # Evaluate the density in log space so an underflowed normaliser never appears in a division
    log_density: Array = normal.logprob(x, mu, sigma) - normal.logcdf(mu / sigma, 0.0, 1.0)

    # Evaluate PDF
    evals = jnp.where(non_negative, jnp.exp(log_density), jnp.array(0.0))

    return evals
|
34
|
+
|
35
|
+
|
36
|
+
def logprob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log of the probability density function (log PDF) for a positive Normal distribution.

    This is the Normal(`mu`, `sigma`) log density minus log Phi(`mu` / `sigma`), the log of the probability
    mass the untruncated Normal places on [0, inf).

    # Parameters
    - `x`: Value(s) at which to evaluate the log PDF.
    - `mu`: The mean/location of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation of the underlying (untruncated) Normal.

    # Returns
    The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are -inf.
    """
    x, mu, sigma = map(jnp.array, (x, mu, sigma))

    # Untruncated Normal log density
    untruncated: Array = normal.logprob(x, mu, sigma)

    # Renormalise onto [0, inf) by subtracting log Phi(mu / sigma)
    log_density: Array = untruncated - normal.logcdf(mu / sigma, 0.0, 1.0)

    # Negative values lie outside the support
    return jnp.where(x >= jnp.array(0.0), log_density, -jnp.inf)
|
63
|
+
|
64
|
+
|
65
|
+
def uprob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The unnormalized probability density function (uPDF) for a positive Normal distribution.

    This is the Normal(`mu`, `sigma`) density restricted to `x >= 0`, without dividing by the truncation
    normaliser Phi(`mu` / `sigma`).

    # Parameters
    - `x`: Value(s) at which to evaluate the uPDF.
    - `mu`: The mean/location parameter(s) of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation parameter(s) of the underlying (untruncated) Normal.

    # Returns
    The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are 0.
    """
    x, mu, sigma = map(jnp.array, (x, mu, sigma))

    # Untruncated Normal density; only its restriction to the support is kept
    untruncated: Array = normal.prob(x, mu, sigma)

    return jnp.where(x >= jnp.array(0.0), untruncated, jnp.array(0.0))
|
92
|
+
|
93
|
+
|
94
|
+
def ulogprob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log of the unnormalized probability density function (log uPDF) for a positive Normal distribution.

    This is the Normal(`mu`, `sigma`) log density restricted to `x >= 0`, without subtracting the log
    truncation normaliser log Phi(`mu` / `sigma`).

    # Parameters
    - `x`: Value(s) at which to evaluate the log uPDF.
    - `mu`: The mean/location parameter(s) of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation parameter(s) of the underlying (untruncated) Normal.

    # Returns
    The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are -inf.
    """
    x, mu, sigma = map(jnp.array, (x, mu, sigma))

    # Untruncated Normal log density; only its restriction to the support is kept
    untruncated: Array = normal.logprob(x, mu, sigma)

    return jnp.where(x >= jnp.array(0.0), untruncated, -jnp.inf)
|
121
|
+
|
122
|
+
|
123
|
+
def cdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a positive Normal distribution.

    Computed as (Phi((x - mu) / sigma) - Phi(-mu / sigma)) / Phi(mu / sigma).

    # Parameters
    - `x`: Value(s) at which to evaluate the CDF.
    - `mu`: The mean/location parameter(s) of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation parameter(s) of the underlying (untruncated) Normal.

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are 0.

    # Notes
    Not numerically stable for small `x`: the numerator subtracts two nearly equal values.
    NOTE(review): the normaliser Phi(`mu` / `sigma`) can underflow to 0 for very negative
    `mu` / `sigma`, giving NaN. Confirm whether callers can reach that regime.
    """
    x, mu, sigma = map(jnp.array, (x, mu, sigma))

    # Phi((x - mu) / sigma): untruncated Normal CDF at x
    upper: Array = normal.cdf(x, mu, sigma)
    # Phi(-mu / sigma): untruncated mass below zero
    lower: Array = normal.cdf(-mu / sigma, 0.0, 1.0)
    # Phi(mu / sigma): untruncated mass on [0, inf)
    mass: Array = normal.cdf(mu / sigma, 0.0, 1.0)

    return jnp.where(x >= jnp.array(0.0), (upper - lower) / mass, jnp.array(0.0))
|
158
|
+
|
159
|
+
# TODO: make numerically stable
|
160
|
+
def logcdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed cumulative distribution function (log CDF) for a positive Normal distribution.

    Computed as log(Phi((x - mu) / sigma) - Phi(-mu / sigma)) - log Phi(mu / sigma). The difference is
    taken in log space as A + log1p(-exp(B - A)).

    # Parameters
    - `x`: Value(s) at which to evaluate the log CDF.
    - `mu`: The mean/location parameter(s) of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation parameter(s) of the underlying (untruncated) Normal.

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are -inf.

    # Notes
    Not numerically stable for small `x`, or when both Phi terms round to the same value (e.g. very
    negative `mu` / `sigma`). Those cases yield -inf.
    """
    # Cast to Array
    x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)

    # Construct boolean mask for non-negative elements
    non_negative: Array = jnp.array(0.0) <= x

    # jnp.where differentiates the discarded branch too. For very negative x, exp(B - A) overflows
    # to inf, and the branch's zero cotangent becomes 0 * inf = NaN in gradients. So evaluate masked
    # entries at x = sigma: one standardised unit above the truncation point, which keeps B < A.
    x_safe: Array = jnp.where(non_negative, x, sigma)

    # Compute intermediates
    log_upper: Array = normal.logcdf(x_safe, mu, sigma)        # log Phi((x - mu) / sigma)
    log_lower: Array = normal.logcdf(-mu / sigma, 0.0, 1.0)    # log Phi(-mu / sigma)
    log_mass: Array = normal.logcdf(mu / sigma, 0.0, 1.0)      # log Phi(mu / sigma)

    # Evaluate log CDF
    evals = jnp.where(
        non_negative,
        log_upper + jnp.log1p(-jnp.exp(log_lower - log_upper)) - log_mass,
        -jnp.inf,
    )

    return evals
|
194
|
+
|
195
|
+
def ccdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF, survival function) for a positive Normal distribution.

    Computed as Phi((mu - x) / sigma) / Phi(mu / sigma).

    # Parameters
    - `x`: Value(s) at which to evaluate the cCDF.
    - `mu`: The mean/location parameter(s) of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation parameter(s) of the underlying (untruncated) Normal.

    # Returns
    The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are 1.

    # Notes
    Not numerically stable for small `x`.
    """
    x, mu, sigma = map(jnp.array, (x, mu, sigma))

    # Untruncated Normal survival probability: Phi((mu - x) / sigma)
    tail: Array = normal.ccdf(x, mu, sigma)
    # Untruncated mass on [0, inf): Phi(mu / sigma)
    mass: Array = normal.cdf(mu / sigma, 0.0, 1.0)

    # Every x < 0 lies below the support, so its survival probability is 1
    return jnp.where(x >= 0.0, tail / mass, jnp.array(1.0))
|
226
|
+
|
227
|
+
|
228
|
+
def logccdf(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed complementary cumulative distribution function (log cCDF) for a positive Normal distribution.

    Computed as log Phi((mu - x) / sigma) - log Phi(mu / sigma).

    # Parameters
    - `x`: Value(s) at which to evaluate the log cCDF.
    - `mu`: The mean/location parameter(s) of the underlying (untruncated) Normal.
    - `sigma`: The positive standard deviation parameter(s) of the underlying (untruncated) Normal.

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    Entries with `x < 0` are 0.

    # Notes
    Not numerically stable for small `x`.
    """
    x, mu, sigma = map(jnp.array, (x, mu, sigma))

    # Untruncated Normal log survival probability: log Phi((mu - x) / sigma)
    log_tail: Array = normal.logccdf(x, mu, sigma)
    # Log of the untruncated mass on [0, inf): log Phi(mu / sigma)
    log_mass: Array = normal.logcdf(mu / sigma, 0.0, 1.0)

    # Every x < 0 lies below the support, so its log survival probability is log(1) = 0
    return jnp.where(x >= 0.0, log_tail - log_mass, jnp.array(0.0))
|
{bayinx-0.3.3.dist-info → bayinx-0.3.5.dist-info}/RECORD
CHANGED
@@ -8,13 +8,17 @@ bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
|
8
8
|
bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
|
9
9
|
bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
|
10
10
|
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
|
-
bayinx/dists/__init__.py,sha256=
|
11
|
+
bayinx/dists/__init__.py,sha256=qPQrl5vkS9K56GzIaHZXkSUP07YAu4lVB8K2yQ1m3SY,78
|
12
12
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
13
|
bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
14
|
-
bayinx/dists/normal.py,sha256=
|
14
|
+
bayinx/dists/normal.py,sha256=BLlp7hGMAxUbroROvzA5ChH5YLXgadeK4VOuBtjjdjs,3978
|
15
|
+
bayinx/dists/posnormal.py,sha256=NNr5OHv1fWCxYvc6hwUMIGXX31UAg0sEnc4tsxHLjUg,7726
|
15
16
|
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
16
|
-
bayinx/dists/censored/__init__.py,sha256=
|
17
|
-
bayinx/dists/censored/gamma2/
|
17
|
+
bayinx/dists/censored/__init__.py,sha256=p8T03TenD-_8YNiOgB_RKksq8hFNFejA5bnoK4JJ8Ms,67
|
18
|
+
bayinx/dists/censored/gamma2/__init__.py,sha256=qqm0n2hfid617PvyFRHAOMAp3AvpOlt5v3ns8HgTD7E,33
|
19
|
+
bayinx/dists/censored/gamma2/r.py,sha256=dE0MNTAl0E6npQhFONv341U7XbomBB-fNzQhgRjxYpk,2436
|
20
|
+
bayinx/dists/censored/posnormal/__init__.py,sha256=qqm0n2hfid617PvyFRHAOMAp3AvpOlt5v3ns8HgTD7E,33
|
21
|
+
bayinx/dists/censored/posnormal/r.py,sha256=4MfFkQ2klzOZJNjxS9g4zz1bdoJ6ehBxZQi6QkmPGgE,2232
|
18
22
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
19
23
|
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
20
24
|
bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
|
@@ -25,6 +29,6 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
|
|
25
29
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
26
30
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
27
31
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
28
|
-
bayinx-0.3.
|
29
|
-
bayinx-0.3.
|
30
|
-
bayinx-0.3.
|
32
|
+
bayinx-0.3.5.dist-info/METADATA,sha256=Hj8GWJef3kfJ6umsHGIFWovYXXtPegAlcsopunoHFFs,3057
|
33
|
+
bayinx-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
34
|
+
bayinx-0.3.5.dist-info/RECORD,,
|
{bayinx-0.3.3.dist-info → bayinx-0.3.5.dist-info}/WHEEL: file without changes
|