bayinx 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/dists/gamma2.py +39 -0
- bayinx/dists/normal.py +11 -11
- bayinx/mhx/vi/flows/fullaffine.py +10 -10
- {bayinx-0.2.28.dist-info → bayinx-0.2.29.dist-info}/METADATA +1 -1
- {bayinx-0.2.28.dist-info → bayinx-0.2.29.dist-info}/RECORD +6 -8
- bayinx/dists/binomial.py +0 -0
- bayinx/dists/gamma.py +0 -0
- {bayinx-0.2.28.dist-info → bayinx-0.2.29.dist-info}/WHEEL +0 -0
bayinx/dists/gamma2.py
CHANGED
@@ -0,0 +1,39 @@
|
|
1
|
+
import jax.lax as lax
|
2
|
+
from jax.scipy.special import gammaln
|
3
|
+
from jaxtyping import Array, ArrayLike, Float, Real
|
4
|
+
|
5
|
+
|
6
|
+
def prob(
|
7
|
+
x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], nu: Real[ArrayLike, "..."]
|
8
|
+
) -> Float[Array, "..."]:
|
9
|
+
"""
|
10
|
+
The probability density function (PDF) for a (mean-precision parameterized) Gamma distribution.
|
11
|
+
|
12
|
+
# Parameters
|
13
|
+
- `x`: Value(s) at which to evaluate the PDF.
|
14
|
+
- `mu`: The mean.
|
15
|
+
- `nu`: The positive inverse dispersion.
|
16
|
+
|
17
|
+
# Returns
|
18
|
+
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
|
19
|
+
"""
|
20
|
+
|
21
|
+
return lax.exp(logprob(x, mu, nu))
|
22
|
+
|
23
|
+
|
24
|
+
def logprob(
|
25
|
+
x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], nu: Real[ArrayLike, "..."]
|
26
|
+
) -> Float[Array, "..."]:
|
27
|
+
"""
|
28
|
+
The log-transformed probability density function (log PDF) for a (mean-precision parameterized) Gamma distribution.
|
29
|
+
|
30
|
+
# Parameters
|
31
|
+
- `x`: Value(s) at which to evaluate the log PDF.
|
32
|
+
- `mu`: The mean/location.
|
33
|
+
- `nu`: The positive inverse dispersion.
|
34
|
+
|
35
|
+
# Returns
|
36
|
+
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
|
37
|
+
"""
|
38
|
+
|
39
|
+
return - gammaln(nu) + nu * (lax.log(nu) - lax.log(mu)) + (nu - 1.0) * lax.log(x) - (x * nu / mu) # pyright: ignore
|
bayinx/dists/normal.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import jax.lax as
|
1
|
+
import jax.lax as lax
|
2
2
|
from jaxtyping import Array, ArrayLike, Float, Real
|
3
3
|
|
4
4
|
__PI = 3.141592653589793
|
@@ -12,15 +12,15 @@ def prob(
|
|
12
12
|
|
13
13
|
# Parameters
|
14
14
|
- `x`: Value(s) at which to evaluate the PDF.
|
15
|
-
- `mu`: The mean/location
|
16
|
-
- `sigma`: The
|
15
|
+
- `mu`: The mean/location.
|
16
|
+
- `sigma`: The positive standard deviation.
|
17
17
|
|
18
18
|
# Returns
|
19
19
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
20
20
|
"""
|
21
21
|
|
22
|
-
return
|
23
|
-
sigma *
|
22
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / ( # pyright: ignore
|
23
|
+
sigma * lax.sqrt(2.0 * __PI)
|
24
24
|
)
|
25
25
|
|
26
26
|
|
@@ -36,11 +36,11 @@ def logprob(
|
|
36
36
|
- `sigma`: The non-negative standard deviation parameter(s).
|
37
37
|
|
38
38
|
# Returns
|
39
|
-
The log
|
39
|
+
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
40
40
|
"""
|
41
41
|
|
42
|
-
return -
|
43
|
-
(x - mu) / sigma
|
42
|
+
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
|
43
|
+
(x - mu) / sigma # pyright: ignore
|
44
44
|
)
|
45
45
|
|
46
46
|
|
@@ -53,13 +53,13 @@ def uprob(
|
|
53
53
|
# Parameters
|
54
54
|
- `x`: Value(s) at which to evaluate the uPDF.
|
55
55
|
- `mu`: The mean/location parameter(s).
|
56
|
-
- `sigma`: The
|
56
|
+
- `sigma`: The positive standard deviation parameter(s).
|
57
57
|
|
58
58
|
# Returns
|
59
59
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
60
60
|
"""
|
61
61
|
|
62
|
-
return
|
62
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma # pyright: ignore
|
63
63
|
|
64
64
|
|
65
65
|
def ulogprob(
|
@@ -77,4 +77,4 @@ def ulogprob(
|
|
77
77
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
78
78
|
"""
|
79
79
|
|
80
|
-
return -
|
80
|
+
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma) # pyright: ignore
|
@@ -30,19 +30,19 @@ class FullAffine(Flow):
|
|
30
30
|
"scale": jnp.zeros((dim, dim)),
|
31
31
|
}
|
32
32
|
|
33
|
-
|
33
|
+
if dim == 1:
|
34
|
+
self.constraints = {}
|
35
|
+
else:
|
34
36
|
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
# Extract diagonal and apply exponential
|
40
|
-
diag: Array = jnp.exp(jnp.diag(params["scale"]))
|
37
|
+
@eqx.filter_jit
|
38
|
+
def constrain_scale(scale: Array):
|
39
|
+
# Extract diagonal and apply exponential
|
40
|
+
diag: Array = jnp.exp(jnp.diag(scale))
|
41
41
|
|
42
|
-
|
43
|
-
|
42
|
+
# Return matrix with modified diagonal
|
43
|
+
return jnp.fill_diagonal(scale, diag, inplace=False)
|
44
44
|
|
45
|
-
|
45
|
+
self.constraints = {"scale": constrain_scale}
|
46
46
|
|
47
47
|
@eqx.filter_jit
|
48
48
|
def forward(self, draws: Array) -> Array:
|
@@ -7,10 +7,8 @@ bayinx/core/model.py,sha256=Z_HaFr0_-keMjG5tg3xxP3hGML7aDFIcCI8Y5dGrtM4,2145
|
|
7
7
|
bayinx/core/variational.py,sha256=W0747jfVJFAtMZqL3mpbtl2wfnARHln-dVBag4xZ23Y,4813
|
8
8
|
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
|
-
bayinx/dists/
|
11
|
-
bayinx/dists/
|
12
|
-
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
bayinx/dists/normal.py,sha256=3CXSgHWnuglmP8cKVUh2Yt4Rb9_LR_mwPRXDm_LuSRo,2679
|
10
|
+
bayinx/dists/gamma2.py,sha256=jc9jiNuuIAv3tdWghQ2Y2ANpJbDscPJKOvs6dOJVJD0,1315
|
11
|
+
bayinx/dists/normal.py,sha256=AXsf3Xe2BfCyzfcQ5i8J9AD92LedQjgqviUMQd697D8,2628
|
14
12
|
bayinx/dists/uniform.py,sha256=mogFe8VuDelM9KXE6RxGek0-tuZYFrwmo_oMOPHXleA,2359
|
15
13
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
16
14
|
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
@@ -18,10 +16,10 @@ bayinx/mhx/vi/meanfield.py,sha256=8hM1KZ52TpRPLwiQcowsJLlQ-5nJzUEcKrtDiGrFoSs,37
|
|
18
16
|
bayinx/mhx/vi/normalizing_flow.py,sha256=FvxDtqGRtaEeeF-bXCYnIEAvOOXVHKUK0oCTF9ma02Y,4622
|
19
17
|
bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
|
20
18
|
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
21
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=
|
19
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=s-fxgzv84BEqNMnmLt6vtvwJqBzixCC2OwWXLz3IK-w,1940
|
22
20
|
bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
|
23
21
|
bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
|
24
22
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
25
|
-
bayinx-0.2.
|
26
|
-
bayinx-0.2.
|
27
|
-
bayinx-0.2.
|
23
|
+
bayinx-0.2.29.dist-info/METADATA,sha256=kfq898UxqnChe-87HPe46_icfS2btMbJPp3INv3IXo0,3058
|
24
|
+
bayinx-0.2.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
25
|
+
bayinx-0.2.29.dist-info/RECORD,,
|
bayinx/dists/binomial.py
DELETED
File without changes
|
bayinx/dists/gamma.py
DELETED
File without changes
|
File without changes
|