bayinx 0.2.27__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/__init__.py CHANGED
@@ -1 +1,2 @@
- from bayinx.core.model import Model as Model
+ from bayinx.core import Model as Model
+ from bayinx.core import Parameter as Parameter
bayinx/constraints/__init__.py ADDED
@@ -0,0 +1 @@
+ from bayinx.constraints.lower import Lower as Lower
bayinx/constraints/lower.py ADDED
@@ -0,0 +1,51 @@
+ from typing import Tuple
+
+ import equinox as eqx
+ import jax.numpy as jnp
+ import jax.tree as jt
+ from jaxtyping import PyTree, Scalar, ScalarLike
+
+ from bayinx.core.constraint import Constraint
+ from bayinx.core.parameter import Parameter
+
+
+ class Lower(Constraint):
+     """
+     Enforces a lower bound on the parameter.
+     """
+
+     lb: Scalar
+
+     def __init__(self, lb: ScalarLike):
+         self.lb = jnp.array(lb)
+
+     @eqx.filter_jit
+     def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+         """
+         Enforces a lower bound on the parameter and adjusts the posterior density.
+
+         # Parameters
+         - `x`: The unconstrained `Parameter`.
+
+         # Returns
+         A tuple containing:
+         - A modified `Parameter` with relevant leaves satisfying the constraint.
+         - A scalar Array representing the log-absolute-Jacobian of the transformation.
+         """
+         # Extract relevant filter specification
+         filter_spec = x.filter_spec
+
+         # Extract relevant parameters (all Array)
+         dyn_params, static_params = eqx.partition(x, filter_spec)
+
+         # Compute density adjustment
+         laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
+         laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
+
+         # Compute transformation
+         dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
+
+         # Combine into full parameter object
+         x = eqx.combine(dyn_params, static_params)
+
+         return x, laj
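A minimal usage sketch of the new constraint API (values are illustrative; `Lower` and `Parameter` come from the public modules added above):

import jax.numpy as jnp

from bayinx import Parameter
from bayinx.constraints import Lower

# An unconstrained parameter; exp(v) + lb maps each leaf above the bound
sigma = Parameter(jnp.array([-1.0, 0.5]))

# Constrain to be positive and collect the density adjustment
constrained, laj = Lower(0.0).constrain(sigma)

# laj is the summed log-absolute-Jacobian of v -> exp(v) + lb,
# i.e. the unconstrained values summed: -1.0 + 0.5
print(constrained.vals, laj)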
bayinx/core/__init__.py CHANGED
@@ -1,3 +1,4 @@
  from bayinx.core.flow import Flow as Flow
  from bayinx.core.model import Model as Model
+ from bayinx.core.parameter import Parameter as Parameter
  from bayinx.core.variational import Variational as Variational
bayinx/core/constraint.py ADDED
@@ -0,0 +1,28 @@
+ from abc import abstractmethod
+ from typing import Tuple
+
+ import equinox as eqx
+ from jaxtyping import Scalar
+
+ from bayinx.core.parameter import Parameter
+
+
+ class Constraint(eqx.Module):
+     """
+     Abstract base class for defining parameter constraints.
+     """
+
+     @abstractmethod
+     def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+         """
+         Applies the constraining transformation to a parameter and computes the log-absolute-Jacobian of the transformation.
+
+         # Parameters
+         - `x`: The unconstrained `Parameter`.
+
+         # Returns
+         A tuple containing:
+         - The constrained `Parameter`.
+         - A scalar Array representing the log-absolute-Jacobian of the transformation.
+         """
+         pass
bayinx/core/flow.py CHANGED
@@ -8,11 +8,11 @@ from jaxtyping import Array, Float

  class Flow(eqx.Module):
      """
-     A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
+     An abstract base class for a flow (of a normalizing flow).

      # Attributes
      - `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
-     - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+     - `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
      """

      params: Dict[str, Float[Array, "..."]]
@@ -28,14 +28,16 @@ class Flow(eqx.Module):
      @abstractmethod
      def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
          """
-         Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
+         Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.

          # Returns
-         A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
+         A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
          """
          pass

      # Default filter specification
+     @property
+     @eqx.filter_jit
      def filter_spec(self):
          """
          Generates a filter specification to subset relevant parameters for the flow.
@@ -53,7 +55,7 @@ class Flow(eqx.Module):
          return filter_spec

      @eqx.filter_jit
-     def constrain_pars(self: Self):
+     def constrain_params(self: Self):
          """
          Constrain `params` to the appropriate domain.

@@ -68,11 +70,11 @@ class Flow(eqx.Module):
          return t_params

      @eqx.filter_jit
-     def transform_pars(self: Self) -> Dict[str, Array]:
+     def transform_params(self: Self) -> Dict[str, Array]:
          """
          Apply a custom transformation to `params` if needed.

          # Returns
          A dictionary of transformed JAX Arrays representing the transformed parameters.
          """
-         return self.constrain_pars()
+         return self.constrain_params()
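For downstream code, the renames above are a mechanical migration; a hedged sketch, assuming `flow` is an instance of any concrete `Flow` subclass:

# bayinx 0.2.x
params = flow.transform_pars()
spec = flow.filter_spec()   # was a method

# bayinx 0.3.x
params = flow.transform_params()
spec = flow.filter_spec     # now a property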
bayinx/core/model.py CHANGED
@@ -1,24 +1,25 @@
  from abc import abstractmethod
- from typing import Any, Dict, Tuple
+ from typing import Any, Dict, Generic, Tuple, TypeVar

  import equinox as eqx
  import jax.numpy as jnp
- import jax.tree_util as jtu
- from jaxtyping import Array, Scalar
+ import jax.tree as jt
+ from jaxtyping import PyTree, Scalar

- from bayinx.core.constraints import Constraint
+ from bayinx.core.constraint import Constraint
+ from bayinx.core.parameter import Parameter

-
- class Model(eqx.Module):
+ P = TypeVar('P', bound=Dict[str, Parameter[PyTree]])
+ class Model(eqx.Module, Generic[P]):
      """
-     A superclass used to define probabilistic models.
+     An abstract base class used to define probabilistic models.

      # Attributes
-     - `params`: A dictionary of JAX Arrays representing parameters of the model.
-     - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+     - `params`: A dictionary of parameters.
+     - `constraints`: A dictionary of constraints.
      """

-     params: Dict[str, Array]
+     params: P
      constraints: Dict[str, Constraint]

      @abstractmethod
@@ -26,32 +27,34 @@
          pass

      # Default filter specification
+     @property
+     @eqx.filter_jit
      def filter_spec(self):
          """
          Generates a filter specification to subset relevant parameters for the model.
          """
          # Generate empty specification
-         filter_spec = jtu.tree_map(lambda _: False, self)
+         filter_spec = jt.map(lambda _: False, self)

-         # Specify JAX Array parameters
+         # Specify relevant parameters
          filter_spec = eqx.tree_at(
              lambda model: model.params,
              filter_spec,
-             replace=jtu.tree_map(eqx.is_array, self.params),
+             replace={key: param.filter_spec for key, param in self.params.items()}
          )

          return filter_spec

      # Add constrain method
      @eqx.filter_jit
-     def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+     def constrain_params(self) -> Tuple[P, Scalar]:
          """
          Constrain `params` to the appropriate domain.

          # Returns
-         A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
+         A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
          """
-         t_params: Dict[str, Array] = self.params
+         t_params: P = self.params
          target: Scalar = jnp.array(0.0)

          for par, map in self.constraints.items():
@@ -63,12 +66,13 @@

          return t_params, target

-
-     def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
+     # Add default transform method
+     @eqx.filter_jit
+     def transform_params(self) -> Tuple[P, Scalar]:
          """
          Apply a custom transformation to `params` if needed.

          # Returns
          A dictionary of transformed JAX Arrays representing the transformed parameters.
          """
-         return self.constrain_pars()
+         return self.constrain_params()
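A hypothetical sketch of a model under the new generic signature. The abstract method is cut off by the hunk above, so its name is assumed here to be an `eval(self, data)`-style log-density; the model itself is invented for illustration:

import jax.numpy as jnp
from jaxtyping import Scalar

from bayinx import Model, Parameter
from bayinx.constraints import Lower
from bayinx.dists import normal


class Toy(Model):
    # One positive scale parameter, kept positive via Lower(0.0)
    def __init__(self):
        self.params = {"sigma": Parameter(jnp.array(0.0))}
        self.constraints = {"sigma": Lower(0.0)}

    # Method name assumed; the abstract method is elided in the hunk above
    def eval(self, data) -> Scalar:
        # Constrained parameters plus the log-absolute-Jacobian adjustment
        params, target = self.constrain_params()
        sigma = params["sigma"].vals
        return target + normal.logprob(data, 0.0, sigma).sum()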
bayinx/core/parameter.py ADDED
@@ -0,0 +1,41 @@
+ from typing import Generic, Self, TypeVar
+
+ import equinox as eqx
+ import jax.tree as jt
+ from jaxtyping import PyTree
+
+ T = TypeVar('T', bound=PyTree)
+ class Parameter(eqx.Module, Generic[T]):
+     """
+     A container for a parameter of a `Model`.
+
+     Subclasses can be constructed for custom filter specifications (`filter_spec`).
+
+     # Attributes
+     - `vals`: The parameter's value(s).
+     """
+     vals: T
+
+
+     def __init__(self, values: T):
+         # Insert parameter values
+         self.vals = values
+
+     # Default filter specification
+     @property
+     @eqx.filter_jit
+     def filter_spec(self) -> Self:
+         """
+         Generates a filter specification to filter out static parameters.
+         """
+         # Generate empty specification
+         filter_spec = jt.map(lambda _: False, self)
+
+         # Specify Array leaves
+         filter_spec = eqx.tree_at(
+             lambda params: params.vals,
+             filter_spec,
+             replace=jt.map(eqx.is_array_like, self.vals),
+         )
+
+         return filter_spec
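The default specification marks array-like leaves of `vals` as dynamic and everything else as static; a small illustration with toy values:

import equinox as eqx
import jax.numpy as jnp

from bayinx.core import Parameter

# One array leaf (dynamic) and one string leaf (static)
p = Parameter({"weights": jnp.ones(3), "label": "beta"})

# `filter_spec` is a property in 0.3.x, so no call parentheses
dyn, static = eqx.partition(p, p.filter_spec)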
bayinx/core/variational.py CHANGED
@@ -8,7 +8,7 @@ import jax.lax as lax
  import jax.numpy as jnp
  import jax.random as jr
  import optax as opx
- from jaxtyping import Array, Float, Key, PyTree, Scalar
+ from jaxtyping import Array, Key, PyTree, Scalar
  from optax import GradientTransformation, OptState, Schedule

  from bayinx.core import Model
@@ -16,16 +16,23 @@ from bayinx.core import Model

  class Variational(eqx.Module):
      """
-     A superclass used to define variational methods.
+     An abstract base class used to define variational methods.

      # Attributes
-     - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
-     - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
+     - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
+     - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
      """

-     _unflatten: Callable[[Float[Array, "..."]], Model]
+     _unflatten: Callable[[Array], Model]
      _constraints: Model

+     @abstractmethod
+     def filter_spec(self):
+         """
+         Filter specification for dynamic and static components of the `Variational`.
+         """
+         pass
+
      @abstractmethod
      def sample(self, n: int, key: Key) -> Array:
          """
@@ -54,13 +61,6 @@
          """
          pass

-     @abstractmethod
-     def filter_spec(self):
-         """
-         Filter specification for dynamic and static components of the `Variational`.
-         """
-         pass
-
      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0, None))
      def eval_model(self, draws: Array, data: Any = None) -> Array:
@@ -103,7 +103,7 @@
          - `key`: A PRNG key.
          """
          # Partition variational
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          # Construct scheduler
          schedule: Schedule = opx.cosine_decay_schedule(
@@ -135,7 +135,7 @@
              # Update PRNG key
              key, _ = jr.split(key)

-             # Combine variational
+             # Reconstruct variational
              vari = eqx.combine(dyn, static)

              # Compute gradient of the ELBO
@@ -143,7 +143,7 @@

              # Compute updates
              updates, opt_state = optim.update(
-                 updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
+                 updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
              )

              # Update variational distribution
bayinx/dists/censored/gamma2/r.py ADDED
@@ -0,0 +1,65 @@
+ import jax.lax as lax
+ import jax.numpy as jnp
+ from jax.scipy.special import gammaincc
+ from jaxtyping import Array, ArrayLike, Float
+
+ from bayinx.dists import gamma2
+
+
+ def prob(
+     x: Float[ArrayLike, "..."],
+     mu: Float[ArrayLike, "..."],
+     nu: Float[ArrayLike, "..."],
+     censor: Float[ArrayLike, "..."]
+ ) -> Float[Array, "..."]:
+     """
+     The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the PMF/PDF.
+     - `mu`: The positive mean.
+     - `nu`: The positive inverse dispersion.
+
+     # Returns
+     The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
+     """
+     evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
+
+     # Construct boolean masks
+     uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
+     censored: Array = jnp.array(x == censor) # pyright: ignore
+
+     # Evaluate mixed probability function
+     evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
+     evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
+
+     return evals
+
+
+ def logprob(
+     x: Float[ArrayLike, "..."],
+     mu: Float[ArrayLike, "..."],
+     nu: Float[ArrayLike, "..."],
+     censor: Float[ArrayLike, "..."]
+ ) -> Float[Array, "..."]:
+     """
+     The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the log PMF/PDF.
+     - `mu`: The positive mean/location.
+     - `nu`: The positive inverse dispersion.
+
+     # Returns
+     The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
+     """
+     evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
+
+     # Construct boolean masks
+     uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
+     censored: Array = jnp.array(x == censor) # pyright: ignore
+
+     evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
+     evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
+
+     return evals
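A hedged usage sketch for the new right-censored density (the import path follows the RECORD entry `bayinx/dists/censored/gamma2/r.py`; values are illustrative):

import jax.numpy as jnp

# gamma2/ ships without an __init__.py per the RECORD, so this import
# assumes namespace-package resolution
from bayinx.dists.censored.gamma2 import r

x = jnp.array([0.5, 1.0, 2.0])

# Observations equal to `censor` receive the survival mass
# P(X >= censor) = gammaincc(nu, censor * nu / mu); the rest use the
# ordinary gamma2 density
evals = r.logprob(x, mu=1.0, nu=2.0, censor=2.0)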
bayinx/dists/gamma2.py CHANGED
@@ -0,0 +1,39 @@
+ import jax.lax as lax
+ from jax.scipy.special import gammaln
+ from jaxtyping import Array, ArrayLike, Float
+
+
+ def prob(
+     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
+ ) -> Float[Array, "..."]:
+     """
+     The probability density function (PDF) for a (mean-precision parameterized) Gamma distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the PDF.
+     - `mu`: The positive mean.
+     - `nu`: The positive inverse dispersion.
+
+     # Returns
+     The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
+     """
+
+     return lax.exp(logprob(x, mu, nu))
+
+
+ def logprob(
+     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
+ ) -> Float[Array, "..."]:
+     """
+     The log-transformed probability density function (log PDF) for a (mean-precision parameterized) Gamma distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the log PDF.
+     - `mu`: The positive mean/location.
+     - `nu`: The positive inverse dispersion.
+
+     # Returns
+     The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
+     """
+
+     return - gammaln(nu) + nu * (lax.log(nu) - lax.log(mu)) + (nu - 1.0) * lax.log(x) - (x * nu / mu) # pyright: ignore
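As a sanity check, this parameterization matches a Gamma with shape `nu` and rate `nu / mu`; a quick comparison against `jax.scipy.stats` (illustrative values):

import jax.numpy as jnp
from jax.scipy.stats import gamma as jax_gamma

from bayinx.dists import gamma2

x, mu, nu = 1.5, 2.0, 3.0

# shape a = nu, scale = mu / nu  (equivalently rate = nu / mu)
assert jnp.allclose(
    gamma2.logprob(x, mu, nu),
    jax_gamma.logpdf(x, a=nu, scale=mu / nu),
)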
bayinx/dists/normal.py CHANGED
@@ -1,31 +1,31 @@
- import jax.lax as _lax
- from jaxtyping import Array, ArrayLike, Float, Real
+ import jax.lax as lax
+ from jaxtyping import Array, ArrayLike, Float

  __PI = 3.141592653589793


  def prob(
-     x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The probability density function (PDF) for a Normal distribution.

      # Parameters
      - `x`: Value(s) at which to evaluate the PDF.
-     - `mu`: The mean/location parameter(s).
-     - `sigma`: The non-negative standard deviation parameter(s).
+     - `mu`: The mean/location.
+     - `sigma`: The positive standard deviation.

      # Returns
      The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
      """

-     return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / ( # pyright: ignore
-         sigma * _lax.sqrt(2.0 * __PI)
+     return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / ( # pyright: ignore
+         sigma * lax.sqrt(2.0 * __PI)
      )


  def logprob(
-     x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The log of the probability density function (log PDF) for a Normal distribution.
@@ -36,14 +36,16 @@ def logprob(
      - `sigma`: The non-negative standard deviation parameter(s).

      # Returns
-     The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+     The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
      """

-     return -_lax.log(sigma * _lax.sqrt(2.0 * __PI)) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
+     return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
+         (x - mu) / sigma # pyright: ignore
+     )


  def uprob(
-     x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The unnormalized probability density function (uPDF) for a Normal distribution.
@@ -51,17 +53,17 @@
      # Parameters
      - `x`: Value(s) at which to evaluate the uPDF.
      - `mu`: The mean/location parameter(s).
-     - `sigma`: The non-negative standard deviation parameter(s).
+     - `sigma`: The positive standard deviation parameter(s).

      # Returns
      The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
      """

-     return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma # pyright: ignore
+     return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma # pyright: ignore


  def ulogprob(
-     x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
@@ -75,4 +77,4 @@
      The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
      """

-     return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
+     return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma) # pyright: ignore
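With `Real` tightened to `Float`, the annotations now promise floating-point inputs; broadcasting is unchanged, e.g.:

import jax.numpy as jnp

from bayinx.dists import normal

# Scalar mu/sigma broadcast against x of shape (3,)
evals = normal.logprob(jnp.array([-1.0, 0.0, 1.0]), 0.0, 1.0)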
bayinx/dists/uniform.py CHANGED
@@ -1,10 +1,10 @@
  import jax.lax as _lax
  import jax.numpy as jnp
- from jaxtyping import Array, ArrayLike, Float, Real
+ from jaxtyping import Array, ArrayLike, Float


  def prob(
-     x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The probability density function (PDF) for a Uniform distribution.
@@ -18,11 +18,11 @@ def prob(
      The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """

-     return 1.0 / (ub - lb) # pyright: ignore
+     return 1.0 / (ub - lb)  # pyright: ignore


  def logprob(
-     x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The log of the probability density function (log PDF) for a Uniform distribution.
@@ -36,11 +36,11 @@ def logprob(
      The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """

-     return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
+     return _lax.log(1.0) - _lax.log(ub - lb)  # pyright: ignore


  def uprob(
-     x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The unnormalized probability density function (uPDF) for a Uniform distribution.
@@ -54,11 +54,11 @@ def uprob(
      The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """

-     return jnp.ones(jnp.broadcast_arrays(x,lb,ub))
+     return jnp.ones(jnp.broadcast_arrays(x, lb, ub))


  def ulogprob(
-     x: Real[ArrayLike, "..."], lb: Real[ArrayLike, "..."], ub: Real[ArrayLike, "..."]
+     x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
  ) -> Float[Array, "..."]:
      """
      The log of the unnormalized probability density function (log uPDF) for a Uniform distribution.
@@ -72,4 +72,4 @@
      The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
      """

-     return jnp.zeros(jnp.broadcast_arrays(x,lb,ub))
+     return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
bayinx/mhx/vi/flows/fullaffine.py CHANGED
@@ -1,29 +1,26 @@
  from functools import partial
- from typing import Callable, Dict, Tuple
+ from typing import Tuple

  import equinox as eqx
  import jax
  import jax.numpy as jnp
- from jaxtyping import Array, Float, Scalar
+ from jaxtyping import Array, Scalar

  from bayinx.core import Flow


  class FullAffine(Flow):
      """
-     An affine flow.
+     A full affine flow.

      # Attributes
      - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
      - `constraints`: A dictionary of constraining transformations.
      """

-     params: Dict[str, Float[Array, "..."]]
-     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
-
      def __init__(self, dim: int):
          """
-         Initializes an affine flow.
+         Initializes a full affine flow.

          # Parameters
          - `dim`: The dimension of the parameter space.
@@ -33,26 +30,23 @@ class FullAffine(Flow):
              "scale": jnp.zeros((dim, dim)),
          }

-         self.constraints = {"scale": lambda m: jnp.tril(m)}
-
-     def transform_pars(self):
-         # Get constrained parameters
-         params = self.constrain_pars()
-
-         # Extract diagonal and apply exponential
-         diag: Array = jnp.exp(jnp.diag(params['scale']))
-
-         # Fill diagonal
-         params['scale'] = jnp.fill_diagonal(params['scale'], diag, inplace=False)
-
+         if dim == 1:
+             self.constraints = {}
+         else:

-         return params
+             @eqx.filter_jit
+             def constrain_scale(scale: Array):
+                 # Extract diagonal and apply exponential
+                 diag: Array = jnp.exp(jnp.diag(scale))

+                 # Return matrix with modified diagonal
+                 return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)

+             self.constraints = {"scale": constrain_scale}

      @eqx.filter_jit
      def forward(self, draws: Array) -> Array:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          shift: Array = params["shift"]
@@ -65,8 +59,8 @@

      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0))
-     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
-         params = self.transform_pars()
+     def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
+         params = self.transform_params()

          # Extract parameters
          shift: Array = params["shift"]
@@ -75,7 +69,7 @@
          # Compute forward transformation
          draws = draws @ scale + shift

-         # Compute ladj
-         ladj: Scalar = jnp.log(jnp.diag(scale)).sum()
+         # Compute laj
+         laj: Scalar = jnp.log(jnp.diag(scale)).sum()

-         return ladj, draws
+         return draws, laj
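Note that the return order of `adjust_density` flipped from `(ladj, draws)` to `(draws, laj)` across all flows; a hedged call-site sketch, importing via the module path from the RECORD:

import jax.random as jr

from bayinx.mhx.vi.flows.fullaffine import FullAffine

flow = FullAffine(dim=2)
draws = jr.normal(jr.PRNGKey(0), (5, 2))

# 0.3.x: transformed draws first, per-draw log-absolute-Jacobian second
draws, laj = flow.adjust_density(draws)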
bayinx/mhx/vi/flows/planar.py CHANGED
@@ -39,7 +39,7 @@ class Planar(Flow):
      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0))
      def forward(self, draws: Array) -> Array:
-         params = self.transform_pars()
+         params = self.transform_params()

          # Extract parameters
          w: Array = params["w"]
@@ -53,8 +53,8 @@

      @eqx.filter_jit
      @partial(jax.vmap, in_axes=(None, 0))
-     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
-         params = self.transform_pars()
+     def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
+         params = self.transform_params()

          # Extract parameters
          w: Array = params["w"]
@@ -67,8 +67,8 @@
          # Compute forward transformation
          draws = draws + u * jnp.tanh(x)

-         # Compute ladj
+         # Compute laj
          h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
-         ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
+         laj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))

-         return ladj, draws
+         return draws, laj
bayinx/mhx/vi/flows/radial.py CHANGED
@@ -49,7 +49,7 @@ class Radial(Flow):
      # Returns
      The transformed samples.
      """
-     params = self.transform_pars()
+     params = self.transform_params()

      # Extract parameters
      alpha = params["alpha"]
@@ -66,8 +66,8 @@

      @partial(jax.vmap, in_axes=(None, 0))
      @eqx.filter_jit
-     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
-         params = self.transform_pars()
+     def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
+         params = self.transform_params()

          # Extract parameters
          alpha = params["alpha"]
@@ -84,11 +84,11 @@
          draws = draws + (x) * (draws - center)

          # Compute density adjustment
-         ladj = jnp.log(
+         laj = jnp.log(
              jnp.abs(
                  (1.0 + alpha * beta / (alpha + r) ** 2.0)
                  * (1.0 + x) ** (center.size - 1.0)
              )
          )

-         return ladj, draws
+         return draws, laj
bayinx/mhx/vi/meanfield.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, Callable, Dict, Self
+ from typing import Any, Dict, Self

  import equinox as eqx
  import jax.numpy as jnp
@@ -20,8 +20,6 @@ class MeanField(Variational):
      """

      var_params: Dict[str, Float[Array, "..."]]
-     _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
-     _constraints: Model = eqx.field(static=True)

      def __init__(self, model: Model):
          """
@@ -31,7 +29,7 @@ class MeanField(Variational):
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         params, self._constraints = eqx.partition(model, model.filter_spec())
+         params, self._constraints = eqx.partition(model, model.filter_spec)

          # Flatten params component
          params, self._unflatten = ravel_pytree(params)
@@ -42,6 +40,22 @@
              "log_std": jnp.zeros(params.size, dtype=params.dtype),
          }

+     @property
+     @eqx.filter_jit
+     def filter_spec(self):
+         # Generate empty specification
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         # Specify variational parameters
+         filter_spec = eqx.tree_at(
+             lambda mf: mf.var_params,
+             filter_spec,
+             replace=True,
+         )
+
+         return filter_spec
+
+
      @eqx.filter_jit
      def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
          # Sample variational draws
@@ -61,27 +75,12 @@
              sigma=jnp.exp(self.var_params["log_std"]),
          ).sum(axis=1)

-     @eqx.filter_jit
-     def filter_spec(self):
-         filter_spec = jtu.tree_map(lambda _: False, self)
-         filter_spec = eqx.tree_at(
-             lambda mf: mf.var_params,
-             filter_spec,
-             replace=True,
-         )
-         return filter_spec
-
      @eqx.filter_jit
      def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
-         """
-         Estimate the ELBO and its gradient (w.r.t. the variational parameters).
-         """
-         # Partition variational
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_jit
          def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
-             # Combine
              vari = eqx.combine(dyn, static)

              # Sample draws from variational distribution
@@ -100,8 +99,7 @@

      @eqx.filter_jit
      def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-         # Partition
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_grad
          @eqx.filter_jit
bayinx/mhx/vi/normalizing_flow.py CHANGED
@@ -1,11 +1,11 @@
- from typing import Any, Callable, Self, Tuple
+ from typing import Any, Self, Tuple

  import equinox as eqx
  import jax.flatten_util as jfu
  import jax.numpy as jnp
  import jax.random as jr
  import jax.tree_util as jtu
- from jaxtyping import Array, Float, Key, Scalar
+ from jaxtyping import Array, Key, Scalar

  from bayinx.core import Flow, Model, Variational

@@ -17,14 +17,11 @@ class NormalizingFlow(Variational):

      # Attributes
      - `base`: A base variational distribution.
-     - `flows`: An ordered collection of continuously parameterized
-     diffeomorphisms.
+     - `flows`: An ordered collection of continuously parameterized diffeomorphisms.
      """

      flows: list[Flow]
      base: Variational
-     _unflatten: Callable[[Float[Array, "..."]], Model]
-     _constraints: Model

      def __init__(self, base: Variational, flows: list[Flow], model: Model):
          """
@@ -36,7 +33,7 @@
          - `model`: A probabilistic `Model` object.
          """
          # Partition model
-         params, self._constraints = eqx.partition(model, eqx.is_array)
+         params, self._constraints = eqx.partition(model, model.filter_spec)

          # Flatten params component
          _, self._unflatten = jfu.ravel_pytree(params)
@@ -44,6 +41,21 @@
          self.base = base
          self.flows = flows

+     @property
+     @eqx.filter_jit
+     def filter_spec(self):
+         # Generate empty specification
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         # Specify variational parameters based on each flow's filter spec.
+         filter_spec = eqx.tree_at(
+             lambda vari: vari.flows,
+             filter_spec,
+             replace=[flow.filter_spec for flow in self.flows],
+         )
+
+         return filter_spec
+
      @eqx.filter_jit
      def sample(self, n: int, key: Key = jr.PRNGKey(0)):
          """
@@ -65,19 +77,18 @@

          for map in self.flows:
              # Compute adjustment
-             ladj, draws = map.adjust_density(draws)
+             draws, laj = map.adjust_density(draws)

              # Adjust variational density
-             variational_evals = variational_evals - ladj
+             variational_evals = variational_evals - laj

          return variational_evals

      @eqx.filter_jit
      def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
          """
-         Evaluate the posterior and variational densities at the transformed
-         `draws` to avoid extra compute when requiring variational draws for
-         the posterior evaluation.
+         Evaluate the posterior and variational densities together at the
+         transformed `draws` to avoid extra compute.

          # Parameters
          - `draws`: Draws from the base variational distribution.
@@ -91,32 +102,19 @@

          for map in self.flows:
              # Compute adjustment
-             ladj, draws = map.adjust_density(draws)
+             draws, laj = map.adjust_density(draws)

              # Adjust variational density
-             variational_evals = variational_evals - ladj
+             variational_evals = variational_evals - laj

          # Evaluate posterior at final variational draws
          posterior_evals = self.eval_model(draws, data)

          return posterior_evals, variational_evals

-     def filter_spec(self):
-         # Generate empty specification
-         filter_spec = jtu.tree_map(lambda _: False, self)
-
-         # Specify variational parameters based on each flow's filter spec.
-         filter_spec = eqx.tree_at(
-             lambda vari: vari.flows,
-             filter_spec,
-             replace=[flow.filter_spec() for flow in self.flows],
-         )
-
-         return filter_spec
-
      @eqx.filter_jit
      def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_jit
          def elbo(dyn: Self, n: int, key: Key, data: Any = None):
@@ -133,7 +131,7 @@

      @eqx.filter_jit
      def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
-         dyn, static = eqx.partition(self, self.filter_spec())
+         dyn, static = eqx.partition(self, self.filter_spec)

          @eqx.filter_grad
          @eqx.filter_jit
bayinx/mhx/vi/standard.py CHANGED
@@ -19,7 +19,7 @@ class Standard(Variational):
      - `dim`: Dimension of the parameter space.
      """

-     dim: int = eqx.field(static=True)
+     dim: int
      _unflatten: Callable[[Float[Array, "..."]], Model]
      _constraints: Model

@@ -31,7 +31,7 @@
      - `model`: A probabilistic `Model` object.
      """
      # Partition model
-     params, self._constraints = eqx.partition(model, model.filter_spec())
+     params, self._constraints = eqx.partition(model, model.filter_spec)

      # Flatten params component
      params, self._unflatten = ravel_pytree(params)
@@ -54,7 +54,7 @@
          sigma=jnp.array(1.0),
      ).sum(axis=1, keepdims=True)

-     @eqx.filter_jit
+     @property
      def filter_spec(self):
          filter_spec = jtu.tree_map(lambda _: False, self)

bayinx-0.2.27.dist-info/METADATA → bayinx-0.3.2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayinx
- Version: 0.2.27
+ Version: 0.3.2
  Summary: Bayesian Inference with JAX
  Requires-Python: >=3.12
  Requires-Dist: equinox>=0.11.12
bayinx-0.3.2.dist-info/RECORD ADDED
@@ -0,0 +1,30 @@
+ bayinx/__init__.py,sha256=htihTsJ54k-ljBLzt4ye8DR7ORwZhxv-bLMcEaDQeqY,86
+ bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
+ bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
+ bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
+ bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
+ bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
+ bayinx/core/model.py,sha256=ADSMapUJGyvKf_TpeC7Foaa3BJ03_Kc7FZxIEKNQkZE,2228
+ bayinx/core/parameter.py,sha256=oxCCZcZ-DDBvfWzexfhQkSJPxNQnE1vYXtBhiEZG2bM,1025
+ bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
+ bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
+ bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
+ bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
+ bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
+ bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
+ bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
+ bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
+ bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
+ bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
+ bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
+ bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
+ bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
+ bayinx-0.3.2.dist-info/METADATA,sha256=9cltWLDiwqg6VpnufQfKYEw_5ZCywJRp7gAPZAogLlA,3057
+ bayinx-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ bayinx-0.3.2.dist-info/RECORD,,
@@ -1,61 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Tuple
3
-
4
- import equinox as eqx
5
- import jax.numpy as jnp
6
- from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
7
-
8
-
9
- class Constraint(eqx.Module):
10
- """
11
- Abstract base class for defining parameter constraints.
12
-
13
- Subclasses should implement the `constrain` method to apply the
14
- transformation and compute the ladj adjustment.
15
- """
16
- @abstractmethod
17
- def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
18
- """
19
- Applies the constraining transformation to an unconstrained input
20
- and computes the log absolute determinant of the Jacobian (ladj)
21
- of this transformation.
22
-
23
- # Parameters
24
- - `x`: The unconstrained JAX Array-like input.
25
-
26
- # Returns
27
- A tuple containing:
28
- - The constrained JAX Array.
29
- - A scalar JAX Array representing the ladj of the transformation.
30
- """
31
- pass
32
-
33
-
34
- class LowerBound(Constraint):
35
- """
36
- Enforces a lower bound on the parameter.
37
- """
38
- lb: ScalarLike
39
-
40
- def __init__(self, lb: ScalarLike):
41
- self.lb = lb
42
-
43
- def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
44
- """
45
- Applies the lower bound constraint and computes the ladj.
46
-
47
- # Parameters
48
- - `x`: The unconstrained JAX Array-like input.
49
-
50
- # Parameters
51
- A tuple containing:
52
- - The constrained JAX Array (x > self.lb).
53
- - A scalar JAX Array representing the ladj of the transformation.
54
- """
55
- # Compute transformation adjustment
56
- ladj = jnp.sum(x)
57
-
58
- # Compute transformation
59
- x = jnp.exp(x) + self.lb
60
-
61
- return x, ladj
bayinx/core/utils.py DELETED
@@ -1 +0,0 @@
-
bayinx/dists/gamma.py DELETED
File without changes
@@ -1,28 +0,0 @@
1
- bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
- bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
5
- bayinx/core/flow.py,sha256=9swS5wh7AsIZWgG_IhQS-upcPlr7G-juaP_5rsbX6G0,2363
6
- bayinx/core/model.py,sha256=U1xBnAXnIvFJjWF-XIT8BYjP9PtoRZY_PwyhRwJf-HA,2144
7
- bayinx/core/utils.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
- bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
9
- bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
11
- bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- bayinx/dists/normal.py,sha256=rtSDi0NAObH1LGRWiPZk_6cbSVv2dOPHkgxtWn6gFgM,2662
15
- bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
16
- bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
17
- bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
18
- bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
19
- bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
20
- bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
21
- bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
22
- bayinx/mhx/vi/flows/fullaffine.py,sha256=nXcPTZ_GSxIg7tmVxag694Fl1F95SKFSyDyt-9EDC9I,2049
23
- bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
24
- bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
25
- bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
26
- bayinx-0.2.27.dist-info/METADATA,sha256=5RPhGKmb6wWJquxrUlyt6QXWTSPEQycu5nFVZmQN9bU,3058
27
- bayinx-0.2.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
- bayinx-0.2.27.dist-info/RECORD,,
bayinx/dists/binomial.py DELETED
File without changes