bayinx 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/dists/gamma2.py CHANGED
@@ -0,0 +1,39 @@
1
import jax.lax as lax
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
from jaxtyping import Array, ArrayLike, Float, Real
4
+
5
+
6
def prob(
    x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], nu: Real[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    Evaluate the probability density function (PDF) of a Gamma distribution parameterized by its mean and precision.

    # Parameters
    - `x`: Value(s) at which to evaluate the PDF.
    - `mu`: The mean.
    - `nu`: The positive inverse dispersion.

    # Returns
    The PDF evaluated at `x`, computed as the exponential of `logprob(x, mu, nu)`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
    """
    # Compute on the log scale, then exponentiate once at the end.
    log_density = logprob(x, mu, nu)
    return lax.exp(log_density)
22
+
23
+
24
def logprob(
    x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], nu: Real[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed probability density function (log PDF) for a (mean-precision parameterized) Gamma distribution.

    This is the Gamma distribution with shape `nu` and rate `nu / mu`, so its mean is
    `mu` and its variance is `mu**2 / nu`.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PDF.
    - `mu`: The positive mean.
    - `nu`: The positive inverse dispersion.

    # Returns
    The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
    """
    # `jnp.log` (unlike `lax.log`) promotes integer / Python-scalar inputs to float,
    # which the `Real[ArrayLike, ...]` annotations permit.
    log_norm = nu * (jnp.log(nu) - jnp.log(mu)) - gammaln(nu)  # pyright: ignore

    # `xlogy` defines 0 * log(0) = 0, so the exponential case (nu == 1) yields
    # log(1 / mu) at x == 0 instead of NaN from `0 * -inf`.
    return log_norm + xlogy(nu - 1.0, x) - x * nu / mu  # pyright: ignore
bayinx/dists/normal.py CHANGED
@@ -1,4 +1,4 @@
1
- import jax.lax as _lax
1
+ import jax.lax as lax
2
2
  from jaxtyping import Array, ArrayLike, Float, Real
3
3
 
4
4
  __PI = 3.141592653589793
@@ -12,15 +12,15 @@ def prob(
12
12
 
13
13
  # Parameters
14
14
  - `x`: Value(s) at which to evaluate the PDF.
15
- - `mu`: The mean/location parameter(s).
16
- - `sigma`: The non-negative standard deviation parameter(s).
15
+ - `mu`: The mean/location.
16
+ - `sigma`: The positive standard deviation.
17
17
 
18
18
  # Returns
19
19
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
20
20
  """
21
21
 
22
- return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / ( # pyright: ignore
23
- sigma * _lax.sqrt(2.0 * __PI)
22
+ return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / ( # pyright: ignore
23
+ sigma * lax.sqrt(2.0 * __PI)
24
24
  )
25
25
 
26
26
 
@@ -36,11 +36,11 @@ def logprob(
36
36
  - `sigma`: The non-negative standard deviation parameter(s).
37
37
 
38
38
  # Returns
39
- The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
39
+ The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
40
40
  """
41
41
 
42
- return -_lax.log(sigma * _lax.sqrt(2.0 * __PI)) - 0.5 * _lax.square(
43
- (x - mu) / sigma # pyright: ignore
42
+ return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
43
+ (x - mu) / sigma # pyright: ignore
44
44
  )
45
45
 
46
46
 
@@ -53,13 +53,13 @@ def uprob(
53
53
  # Parameters
54
54
  - `x`: Value(s) at which to evaluate the uPDF.
55
55
  - `mu`: The mean/location parameter(s).
56
- - `sigma`: The non-negative standard deviation parameter(s).
56
+ - `sigma`: The positive standard deviation parameter(s).
57
57
 
58
58
  # Returns
59
59
  The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
60
60
  """
61
61
 
62
- return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma # pyright: ignore
62
+ return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma # pyright: ignore
63
63
 
64
64
 
65
65
  def ulogprob(
@@ -77,4 +77,4 @@ def ulogprob(
77
77
  The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
78
78
  """
79
79
 
80
- return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
80
+ return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma) # pyright: ignore
@@ -30,19 +30,19 @@ class FullAffine(Flow):
30
30
  "scale": jnp.zeros((dim, dim)),
31
31
  }
32
32
 
33
- self.constraints = {"scale": lambda m: jnp.tril(m)}
33
+ if dim == 1:
34
+ self.constraints = {}
35
+ else:
34
36
 
35
- @eqx.filter_jit
36
- def transform_pars(self):
37
- params = self.constrain_pars()
38
-
39
- # Extract diagonal and apply exponential
40
- diag: Array = jnp.exp(jnp.diag(params["scale"]))
37
+ @eqx.filter_jit
38
+ def constrain_scale(scale: Array):
39
+ # Extract diagonal and apply exponential
40
+ diag: Array = jnp.exp(jnp.diag(scale))
41
41
 
42
- # Fill diagonal
43
- params["scale"] = jnp.fill_diagonal(params["scale"], diag, inplace=False)
42
+ # Return matrix with modified diagonal
43
+ return jnp.fill_diagonal(scale, diag, inplace=False)
44
44
 
45
- return params
45
+ self.constraints = {"scale": constrain_scale}
46
46
 
47
47
  @eqx.filter_jit
48
48
  def forward(self, draws: Array) -> Array:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.28
3
+ Version: 0.2.29
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -7,10 +7,8 @@ bayinx/core/model.py,sha256=Z_HaFr0_-keMjG5tg3xxP3hGML7aDFIcCI8Y5dGrtM4,2145
7
7
  bayinx/core/variational.py,sha256=W0747jfVJFAtMZqL3mpbtl2wfnARHln-dVBag4xZ23Y,4813
8
8
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
10
- bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- bayinx/dists/normal.py,sha256=3CXSgHWnuglmP8cKVUh2Yt4Rb9_LR_mwPRXDm_LuSRo,2679
10
+ bayinx/dists/gamma2.py,sha256=jc9jiNuuIAv3tdWghQ2Y2ANpJbDscPJKOvs6dOJVJD0,1315
11
+ bayinx/dists/normal.py,sha256=AXsf3Xe2BfCyzfcQ5i8J9AD92LedQjgqviUMQd697D8,2628
14
12
  bayinx/dists/uniform.py,sha256=mogFe8VuDelM9KXE6RxGek0-tuZYFrwmo_oMOPHXleA,2359
15
13
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
16
14
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
@@ -18,10 +16,10 @@ bayinx/mhx/vi/meanfield.py,sha256=8hM1KZ52TpRPLwiQcowsJLlQ-5nJzUEcKrtDiGrFoSs,37
18
16
  bayinx/mhx/vi/normalizing_flow.py,sha256=FvxDtqGRtaEeeF-bXCYnIEAvOOXVHKUK0oCTF9ma02Y,4622
19
17
  bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
20
18
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
21
- bayinx/mhx/vi/flows/fullaffine.py,sha256=Kvaa8epqaqz9tdMCnf9T_-2P3Bh_TkhA6NrilKHY93A,1886
19
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=s-fxgzv84BEqNMnmLt6vtvwJqBzixCC2OwWXLz3IK-w,1940
22
20
  bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
23
21
  bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
24
22
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
25
- bayinx-0.2.28.dist-info/METADATA,sha256=xe3Wlo3UlD3VuTc42ChwnPTL6lp3BZmxnuf0gnZxWv0,3058
26
- bayinx-0.2.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
27
- bayinx-0.2.28.dist-info/RECORD,,
23
+ bayinx-0.2.29.dist-info/METADATA,sha256=kfq898UxqnChe-87HPe46_icfS2btMbJPp3INv3IXo0,3058
24
+ bayinx-0.2.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
25
+ bayinx-0.2.29.dist-info/RECORD,,
bayinx/dists/binomial.py DELETED
File without changes
bayinx/dists/gamma.py DELETED
File without changes