bayinx 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -1,30 +1,9 @@
1
- from abc import abstractmethod
2
1
  from typing import Tuple
3
2
 
4
- import equinox as eqx
5
3
  import jax.numpy as jnp
6
4
  from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
7
5
 
8
-
9
- class Constraint(eqx.Module):
10
- """
11
- Abstract base class for defining parameter constraints.
12
- """
13
-
14
- @abstractmethod
15
- def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
16
- """
17
- Applies the constraining transformation to an unconstrained input and computes the log-absolute-jacobian of the transformation.
18
-
19
- # Parameters
20
- - `x`: The unconstrained JAX Array-like input.
21
-
22
- # Returns
23
- A tuple containing:
24
- - The constrained JAX Array.
25
- - A scalar JAX Array representing the laj of the transformation.
26
- """
27
- pass
6
+ from bayinx.core.constraint import Constraint
28
7
 
29
8
 
30
9
  class LowerBound(Constraint):
@@ -0,0 +1,26 @@
1
+ from abc import abstractmethod
2
+ from typing import Tuple
3
+
4
+ import equinox as eqx
5
+ from jaxtyping import Array, ArrayLike, Scalar
6
+
7
+
8
+ class Constraint(eqx.Module):
9
+ """
10
+ Abstract base class for defining parameter constraints.
11
+ """
12
+
13
+ @abstractmethod
14
+ def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
15
+ """
16
+ Applies the constraining transformation to an unconstrained input and computes the log-absolute-jacobian of the transformation.
17
+
18
+ # Parameters
19
+ - `x`: The unconstrained JAX Array-like input.
20
+
21
+ # Returns
22
+ A tuple containing:
23
+ - The constrained JAX Array.
24
+ - A scalar JAX Array representing the laj of the transformation.
25
+ """
26
+ pass
bayinx/core/model.py CHANGED
@@ -3,10 +3,10 @@ from typing import Any, Dict, Tuple
3
3
 
4
4
  import equinox as eqx
5
5
  import jax.numpy as jnp
6
- import jax.tree_util as jtu
6
+ import jax.tree as jt
7
7
  from jaxtyping import Array, Scalar
8
8
 
9
- from bayinx.core.constraints import Constraint
9
+ from bayinx.core.constraint import Constraint
10
10
 
11
11
 
12
12
  class Model(eqx.Module):
@@ -31,13 +31,13 @@ class Model(eqx.Module):
31
31
  Generates a filter specification to subset relevant parameters for the model.
32
32
  """
33
33
  # Generate empty specification
34
- filter_spec = jtu.tree_map(lambda _: False, self)
34
+ filter_spec = jt.map(lambda _: False, self)
35
35
 
36
36
  # Specify JAX Array parameters
37
37
  filter_spec = eqx.tree_at(
38
38
  lambda model: model.params,
39
39
  filter_spec,
40
- replace=jtu.tree_map(eqx.is_array, self.params),
40
+ replace=jt.map(eqx.is_array, self.params),
41
41
  )
42
42
 
43
43
  return filter_spec
@@ -19,8 +19,8 @@ class Variational(eqx.Module):
19
19
  An abstract base class used to define variational methods.
20
20
 
21
21
  # Attributes
22
- - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
23
- - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
22
+ - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
23
+ - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
24
24
  """
25
25
 
26
26
  _unflatten: Callable[[Array], Model]
File without changes
@@ -0,0 +1,65 @@
1
+ import jax.lax as lax
2
+ import jax.numpy as jnp
3
+ from jax.scipy.special import gammaincc
4
+ from jaxtyping import Array, ArrayLike, Float
5
+
6
+ from bayinx.dists import gamma2
7
+
8
+
9
+ def prob(
10
+ x: Float[ArrayLike, "..."],
11
+ mu: Float[ArrayLike, "..."],
12
+ nu: Float[ArrayLike, "..."],
13
+ censor: Float[ArrayLike, "..."]
14
+ ) -> Float[Array, "..."]:
15
+ """
16
+ The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
17
+
18
+ # Parameters
19
+ - `x`: Value(s) at which to evaluate the PMF/PDF.
20
+ - `mu`: The positive mean.
21
+ - `nu`: The positive inverse dispersion.
22
+
23
+ # Returns
24
+ The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
25
+ """
26
+ evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
27
+
28
+ # Construct boolean masks
29
+ uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
30
+ censored: Array = jnp.array(x == censor) # pyright: ignore
31
+
32
+ # Evaluate mixed probability (?) function
33
+ evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
34
+ evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
35
+
36
+ return evals
37
+
38
+
39
+ def logprob(
40
+ x: Float[ArrayLike, "..."],
41
+ mu: Float[ArrayLike, "..."],
42
+ nu: Float[ArrayLike, "..."],
43
+ censor: Float[ArrayLike, "..."]
44
+ ) -> Float[Array, "..."]:
45
+ """
46
+ The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
47
+
48
+ # Parameters
49
+ - `x`: Value(s) at which to evaluate the log PMF/PDF.
50
+ - `mu`: The positive mean/location.
51
+ - `nu`: The positive inverse dispersion.
52
+
53
+ # Returns
54
+ The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
55
+ """
56
+ evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
57
+
58
+ # Construct boolean masks
59
+ uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
60
+ censored: Array = jnp.array(x == censor) # pyright: ignore
61
+
62
+ evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
63
+ evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
64
+
65
+ return evals
bayinx/dists/gamma2.py CHANGED
@@ -1,17 +1,17 @@
1
1
  import jax.lax as lax
2
2
  from jax.scipy.special import gammaln
3
- from jaxtyping import Array, ArrayLike, Float, Real
3
+ from jaxtyping import Array, ArrayLike, Float
4
4
 
5
5
 
6
6
  def prob(
7
- x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], nu: Real[ArrayLike, "..."]
7
+ x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
8
8
  ) -> Float[Array, "..."]:
9
9
  """
10
10
  The probability density function (PDF) for a (mean-precision parameterized) Gamma distribution.
11
11
 
12
12
  # Parameters
13
13
  - `x`: Value(s) at which to evaluate the PDF.
14
- - `mu`: The mean.
14
+ - `mu`: The positive mean.
15
15
  - `nu`: The positive inverse dispersion.
16
16
 
17
17
  # Returns
@@ -22,14 +22,14 @@ def prob(
22
22
 
23
23
 
24
24
  def logprob(
25
- x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], nu: Real[ArrayLike, "..."]
25
+ x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
26
26
  ) -> Float[Array, "..."]:
27
27
  """
28
28
  The log-transformed probability density function (log PDF) for a (mean-precision parameterized) Gamma distribution.
29
29
 
30
30
  # Parameters
31
31
  - `x`: Value(s) at which to evaluate the log PDF.
32
- - `mu`: The mean/location.
32
+ - `mu`: The positive mean/location.
33
33
  - `nu`: The positive inverse dispersion.
34
34
 
35
35
  # Returns
bayinx/dists/normal.py CHANGED
@@ -1,11 +1,11 @@
1
1
  import jax.lax as lax
2
- from jaxtyping import Array, ArrayLike, Float, Real
2
+ from jaxtyping import Array, ArrayLike, Float
3
3
 
4
4
  __PI = 3.141592653589793
5
5
 
6
6
 
7
7
  def prob(
8
- x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
8
+ x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
9
9
  ) -> Float[Array, "..."]:
10
10
  """
11
11
  The probability density function (PDF) for a Normal distribution.
@@ -25,7 +25,7 @@ def prob(
25
25
 
26
26
 
27
27
  def logprob(
28
- x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
28
+ x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
29
29
  ) -> Float[Array, "..."]:
30
30
  """
31
31
  The log of the probability density function (log PDF) for a Normal distribution.
@@ -45,7 +45,7 @@ def logprob(
45
45
 
46
46
 
47
47
  def uprob(
48
- x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
48
+ x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
49
49
  ) -> Float[Array, "..."]:
50
50
  """
51
51
  The unnormalized probability density function (uPDF) for a Normal distribution.
@@ -63,7 +63,7 @@ def uprob(
63
63
 
64
64
 
65
65
  def ulogprob(
66
- x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
66
+ x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
67
67
  ) -> Float[Array, "..."]:
68
68
  """
69
69
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
bayinx/dists/uniform.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import jax.lax as _lax
2
2
  import jax.numpy as jnp
3
- from jaxtyping import Array, ArrayLike, Float, Real
3
+ from jaxtyping import Array, ArrayLike, Float
4
4
 
5
5
 
6
6
  def prob(
7
- x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
7
+ x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
8
8
  ) -> Float[Array, "..."]:
9
9
  """
10
10
  The probability density function (PDF) for a Uniform distribution.
@@ -22,7 +22,7 @@ def prob(
22
22
 
23
23
 
24
24
  def logprob(
25
- x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
25
+ x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
26
26
  ) -> Float[Array, "..."]:
27
27
  """
28
28
  The log of the probability density function (log PDF) for a Uniform distribution.
@@ -40,7 +40,7 @@ def logprob(
40
40
 
41
41
 
42
42
  def uprob(
43
- x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
43
+ x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
44
44
  ) -> Float[Array, "..."]:
45
45
  """
46
46
  The unnormalized probability density function (uPDF) for a Uniform distribution.
@@ -58,7 +58,7 @@ def uprob(
58
58
 
59
59
 
60
60
  def ulogprob(
61
- x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
61
+ x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
62
62
  ) -> Float[Array, "..."]:
63
63
  """
64
64
  The log of the unnormalized probability density function (log uPDF) for a Uniform distribution.
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, Dict, Self
1
+ from typing import Any, Dict, Self
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
@@ -20,8 +20,6 @@ class MeanField(Variational):
20
20
  """
21
21
 
22
22
  var_params: Dict[str, Float[Array, "..."]]
23
- _unflatten: Callable[[Float[Array, "..."]], Model]
24
- _constraints: Model
25
23
 
26
24
  def __init__(self, model: Model):
27
25
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.29
3
+ Version: 0.2.30
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,18 +1,22 @@
1
1
  bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
2
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ bayinx/constraints/lower.py,sha256=O37qJ6ojRKbKGJlnQ7Vv7P2VGARcnKrOifwyUNBHma8,912
3
5
  bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
- bayinx/core/constraints.py,sha256=lbVs2-xjGRue17YRPGHz3s_mJ0ZiunpYowbD0QvcD-I,1525
6
+ bayinx/core/constraint.py,sha256=T1QP3WV9RU1nPW-9H7loc6DyBzElHVyDp1c4mmZpnmI,733
5
7
  bayinx/core/flow.py,sha256=A5Vw5t76LPasnMgghjw6ulBkIm5L2jBprusVt-tuwko,2296
6
- bayinx/core/model.py,sha256=Z_HaFr0_-keMjG5tg3xxP3hGML7aDFIcCI8Y5dGrtM4,2145
7
- bayinx/core/variational.py,sha256=W0747jfVJFAtMZqL3mpbtl2wfnARHln-dVBag4xZ23Y,4813
8
+ bayinx/core/model.py,sha256=vfEnqBpHE2MtuJPIDgKvVYIv5n53E2e-KAAXEtqEy0c,2126
9
+ bayinx/core/variational.py,sha256=2stsYKZDri1rLP7mrz7X2GWehBXNESdlWtmF2N9CEas,4787
8
10
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
11
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
10
- bayinx/dists/gamma2.py,sha256=jc9jiNuuIAv3tdWghQ2Y2ANpJbDscPJKOvs6dOJVJD0,1315
11
- bayinx/dists/normal.py,sha256=AXsf3Xe2BfCyzfcQ5i8J9AD92LedQjgqviUMQd697D8,2628
12
- bayinx/dists/uniform.py,sha256=mogFe8VuDelM9KXE6RxGek0-tuZYFrwmo_oMOPHXleA,2359
12
+ bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
13
+ bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
14
+ bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
15
+ bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
13
17
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
14
18
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
15
- bayinx/mhx/vi/meanfield.py,sha256=8hM1KZ52TpRPLwiQcowsJLlQ-5nJzUEcKrtDiGrFoSs,3732
19
+ bayinx/mhx/vi/meanfield.py,sha256=BobfTagVGA5R-dclv-E0jSA80KZg1X6GGjiw7XR61vE,3643
16
20
  bayinx/mhx/vi/normalizing_flow.py,sha256=FvxDtqGRtaEeeF-bXCYnIEAvOOXVHKUK0oCTF9ma02Y,4622
17
21
  bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
18
22
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
@@ -20,6 +24,6 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=s-fxgzv84BEqNMnmLt6vtvwJqBzixCC2OwWXLz3
20
24
  bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
21
25
  bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
22
26
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
23
- bayinx-0.2.29.dist-info/METADATA,sha256=kfq898UxqnChe-87HPe46_icfS2btMbJPp3INv3IXo0,3058
24
- bayinx-0.2.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
25
- bayinx-0.2.29.dist-info/RECORD,,
27
+ bayinx-0.2.30.dist-info/METADATA,sha256=VVDldLMllMzTkx_tphb6k2n4u-PSwBbmFOGN-PG1BWc,3058
28
+ bayinx-0.2.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
29
+ bayinx-0.2.30.dist-info/RECORD,,