bayinx 0.2.25__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +2 -1
- bayinx/constraints/__init__.py +1 -0
- bayinx/constraints/lower.py +51 -0
- bayinx/core/__init__.py +1 -0
- bayinx/core/constraint.py +28 -0
- bayinx/core/flow.py +9 -7
- bayinx/core/model.py +23 -19
- bayinx/core/parameter.py +41 -0
- bayinx/core/variational.py +15 -15
- bayinx/dists/censored/gamma2/r.py +65 -0
- bayinx/dists/gamma2.py +39 -0
- bayinx/dists/normal.py +17 -15
- bayinx/dists/uniform.py +9 -9
- bayinx/mhx/vi/flows/fullaffine.py +23 -14
- bayinx/mhx/vi/flows/planar.py +6 -6
- bayinx/mhx/vi/flows/radial.py +5 -5
- bayinx/mhx/vi/meanfield.py +20 -22
- bayinx/mhx/vi/normalizing_flow.py +27 -29
- bayinx/mhx/vi/standard.py +3 -3
- {bayinx-0.2.25.dist-info → bayinx-0.3.2.dist-info}/METADATA +1 -1
- bayinx-0.3.2.dist-info/RECORD +30 -0
- bayinx/core/constraints.py +0 -61
- bayinx/core/utils.py +0 -1
- bayinx/dists/gamma.py +0 -0
- bayinx-0.2.25.dist-info/RECORD +0 -28
- /bayinx/dists/{binomial.py → censored/__init__.py} +0 -0
- {bayinx-0.2.25.dist-info → bayinx-0.3.2.dist-info}/WHEEL +0 -0
bayinx/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1
|
-
from bayinx.core
|
1
|
+
from bayinx.core import Model as Model
|
2
|
+
from bayinx.core import Parameter as Parameter
|
@@ -0,0 +1 @@
|
|
1
|
+
from bayinx.constraints.lower import Lower as Lower
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import Tuple
|
2
|
+
|
3
|
+
import equinox as eqx
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jax.tree as jt
|
6
|
+
from jaxtyping import PyTree, Scalar, ScalarLike
|
7
|
+
|
8
|
+
from bayinx.core.constraint import Constraint
|
9
|
+
from bayinx.core.parameter import Parameter
|
10
|
+
|
11
|
+
|
12
|
+
class Lower(Constraint):
    """
    Enforces a lower bound on the parameter.

    Each dynamic (array) leaf `v` of the parameter is mapped to `exp(v) + lb`.

    # Attributes
    - `lb`: The lower bound, stored as a JAX array.
    """

    lb: Scalar

    def __init__(self, lb: ScalarLike):
        self.lb = jnp.array(lb)

    @eqx.filter_jit
    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
        """
        Enforces a lower bound on the parameter and adjusts the posterior density.

        # Parameters
        - `x`: The unconstrained `Parameter`.

        # Returns
        A tuple containing:
        - A modified `Parameter` with relevant leaves satisfying the constraint.
        - A scalar Array representing the log-absolute-Jacobian of the transformation.
        """
        # Extract relevant filter specification
        filter_spec = x.filter_spec

        # Split into dynamic (array) leaves and static leaves
        dyn_params, static_params = eqx.partition(x, filter_spec)

        # Compute density adjustment: d/dv [exp(v) + lb] = exp(v), so the
        # log-absolute-Jacobian is the sum of all unconstrained values.
        # The explicit initializer keeps this defined (0.0) when `x` has no
        # dynamic leaves; without one, `jt.reduce` raises on an empty tree.
        laj: Scalar = jt.reduce(
            lambda acc, leaf: acc + jnp.sum(leaf), dyn_params, jnp.array(0.0)
        )

        # Compute transformation
        dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)

        # Combine into full parameter object
        x = eqx.combine(dyn_params, static_params)

        return x, laj
|
bayinx/core/__init__.py
CHANGED
@@ -0,0 +1,28 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Tuple
|
3
|
+
|
4
|
+
import equinox as eqx
|
5
|
+
from jaxtyping import Scalar
|
6
|
+
|
7
|
+
from bayinx.core.parameter import Parameter
|
8
|
+
|
9
|
+
|
10
|
+
class Constraint(eqx.Module):
    """
    Base class for parameter constraints.

    Subclasses implement `constrain`. It maps an unconstrained `Parameter` onto
    the constrained domain and reports the log-absolute-Jacobian of that mapping.
    """

    @abstractmethod
    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
        """
        Map `x` onto the constrained domain.

        # Parameters
        - `x`: The unconstrained `Parameter`.

        # Returns
        A pair `(constrained, laj)`:
        - `constrained`: the constrained `Parameter`.
        - `laj`: a scalar Array holding the log-absolute-Jacobian of the transformation.
        """
        ...
|
bayinx/core/flow.py
CHANGED
@@ -8,11 +8,11 @@ from jaxtyping import Array, Float
|
|
8
8
|
|
9
9
|
class Flow(eqx.Module):
|
10
10
|
"""
|
11
|
-
|
11
|
+
An abstract base class for a flow(of a normalizing flow).
|
12
12
|
|
13
13
|
# Attributes
|
14
14
|
- `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
|
15
|
-
- `constraints`: A dictionary of functions that constrain their corresponding parameter.
|
15
|
+
- `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
|
16
16
|
"""
|
17
17
|
|
18
18
|
params: Dict[str, Float[Array, "..."]]
|
@@ -28,14 +28,16 @@ class Flow(eqx.Module):
|
|
28
28
|
@abstractmethod
|
29
29
|
def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
|
30
30
|
"""
|
31
|
-
Computes the log-absolute-
|
31
|
+
Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.
|
32
32
|
|
33
33
|
# Returns
|
34
|
-
|
34
|
+
A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
|
35
35
|
"""
|
36
36
|
pass
|
37
37
|
|
38
38
|
# Default filter specification
|
39
|
+
@property
|
40
|
+
@eqx.filter_jit
|
39
41
|
def filter_spec(self):
|
40
42
|
"""
|
41
43
|
Generates a filter specification to subset relevant parameters for the flow.
|
@@ -53,7 +55,7 @@ class Flow(eqx.Module):
|
|
53
55
|
return filter_spec
|
54
56
|
|
55
57
|
@eqx.filter_jit
|
56
|
-
def
|
58
|
+
def constrain_params(self: Self):
|
57
59
|
"""
|
58
60
|
Constrain `params` to the appropriate domain.
|
59
61
|
|
@@ -68,11 +70,11 @@ class Flow(eqx.Module):
|
|
68
70
|
return t_params
|
69
71
|
|
70
72
|
@eqx.filter_jit
|
71
|
-
def
|
73
|
+
def transform_params(self: Self) -> Dict[str, Array]:
|
72
74
|
"""
|
73
75
|
Apply a custom transformation to `params` if needed.
|
74
76
|
|
75
77
|
# Returns
|
76
78
|
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
77
79
|
"""
|
78
|
-
return self.
|
80
|
+
return self.constrain_params()
|
bayinx/core/model.py
CHANGED
@@ -1,24 +1,25 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
from typing import Any, Dict, Tuple
|
2
|
+
from typing import Any, Dict, Generic, Tuple, TypeVar
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
5
|
import jax.numpy as jnp
|
6
|
-
import jax.
|
7
|
-
from jaxtyping import
|
6
|
+
import jax.tree as jt
|
7
|
+
from jaxtyping import PyTree, Scalar
|
8
8
|
|
9
|
-
from bayinx.core.
|
9
|
+
from bayinx.core.constraint import Constraint
|
10
|
+
from bayinx.core.parameter import Parameter
|
10
11
|
|
11
|
-
|
12
|
-
class Model(eqx.Module):
|
12
|
+
P = TypeVar('P', bound=Dict[str, Parameter[PyTree]])
|
13
|
+
class Model(eqx.Module, Generic[P]):
|
13
14
|
"""
|
14
|
-
|
15
|
+
An abstract base class used to define probabilistic models.
|
15
16
|
|
16
17
|
# Attributes
|
17
|
-
- `params`: A dictionary of
|
18
|
-
- `constraints`: A dictionary of
|
18
|
+
- `params`: A dictionary of parameters.
|
19
|
+
- `constraints`: A dictionary of constraints.
|
19
20
|
"""
|
20
21
|
|
21
|
-
params:
|
22
|
+
params: P
|
22
23
|
constraints: Dict[str, Constraint]
|
23
24
|
|
24
25
|
@abstractmethod
|
@@ -26,32 +27,34 @@ class Model(eqx.Module):
|
|
26
27
|
pass
|
27
28
|
|
28
29
|
# Default filter specification
|
30
|
+
@property
|
31
|
+
@eqx.filter_jit
|
29
32
|
def filter_spec(self):
|
30
33
|
"""
|
31
34
|
Generates a filter specification to subset relevant parameters for the model.
|
32
35
|
"""
|
33
36
|
# Generate empty specification
|
34
|
-
filter_spec =
|
37
|
+
filter_spec = jt.map(lambda _: False, self)
|
35
38
|
|
36
|
-
# Specify
|
39
|
+
# Specify relevant parameters
|
37
40
|
filter_spec = eqx.tree_at(
|
38
41
|
lambda model: model.params,
|
39
42
|
filter_spec,
|
40
|
-
replace=
|
43
|
+
replace={key: param.filter_spec for key, param in self.params.items()}
|
41
44
|
)
|
42
45
|
|
43
46
|
return filter_spec
|
44
47
|
|
45
48
|
# Add constrain method
|
46
49
|
@eqx.filter_jit
|
47
|
-
def
|
50
|
+
def constrain_params(self) -> Tuple[P, Scalar]:
|
48
51
|
"""
|
49
52
|
Constrain `params` to the appropriate domain.
|
50
53
|
|
51
54
|
# Returns
|
52
|
-
A dictionary of
|
55
|
+
A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
|
53
56
|
"""
|
54
|
-
t_params:
|
57
|
+
t_params: P = self.params
|
55
58
|
target: Scalar = jnp.array(0.0)
|
56
59
|
|
57
60
|
for par, map in self.constraints.items():
|
@@ -63,12 +66,13 @@ class Model(eqx.Module):
|
|
63
66
|
|
64
67
|
return t_params, target
|
65
68
|
|
66
|
-
|
67
|
-
|
69
|
+
# Add default transform method
|
70
|
+
@eqx.filter_jit
|
71
|
+
def transform_params(self) -> Tuple[P, Scalar]:
|
68
72
|
"""
|
69
73
|
Apply a custom transformation to `params` if needed.
|
70
74
|
|
71
75
|
# Returns
|
72
76
|
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
73
77
|
"""
|
74
|
-
return self.
|
78
|
+
return self.constrain_params()
|
bayinx/core/parameter.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Generic, Self, TypeVar
|
2
|
+
|
3
|
+
import equinox as eqx
|
4
|
+
import jax.tree as jt
|
5
|
+
from jaxtyping import PyTree
|
6
|
+
|
7
|
+
T = TypeVar('T', bound=PyTree)

class Parameter(eqx.Module, Generic[T]):
    """
    Wrapper holding the value(s) of a single `Model` parameter.

    Subclasses may override `filter_spec` to customise which leaves are
    treated as dynamic and which as static.

    # Attributes
    - `vals`: The parameter's value(s), an arbitrary PyTree.
    """
    vals: T


    def __init__(self, values: T):
        # Store the supplied PyTree as the parameter's value(s)
        self.vals = values

    # Default filter specification
    @property
    @eqx.filter_jit
    def filter_spec(self) -> Self:
        """
        Build a boolean mask selecting the array-like leaves of `vals`.

        Every leaf of the returned tree starts as `False`; the `vals` subtree is
        then replaced by a mask that is `True` exactly where a leaf is array-like.
        """
        array_mask = jt.map(eqx.is_array_like, self.vals)
        all_false = jt.map(lambda _: False, self)

        return eqx.tree_at(lambda p: p.vals, all_false, replace=array_mask)
|
bayinx/core/variational.py
CHANGED
@@ -8,7 +8,7 @@ import jax.lax as lax
|
|
8
8
|
import jax.numpy as jnp
|
9
9
|
import jax.random as jr
|
10
10
|
import optax as opx
|
11
|
-
from jaxtyping import Array,
|
11
|
+
from jaxtyping import Array, Key, PyTree, Scalar
|
12
12
|
from optax import GradientTransformation, OptState, Schedule
|
13
13
|
|
14
14
|
from bayinx.core import Model
|
@@ -16,16 +16,23 @@ from bayinx.core import Model
|
|
16
16
|
|
17
17
|
class Variational(eqx.Module):
|
18
18
|
"""
|
19
|
-
|
19
|
+
An abstract base class used to define variational methods.
|
20
20
|
|
21
21
|
# Attributes
|
22
|
-
- `_unflatten`: A
|
23
|
-
- `_constraints`:
|
22
|
+
- `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
|
23
|
+
- `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
|
24
24
|
"""
|
25
25
|
|
26
|
-
_unflatten: Callable[[
|
26
|
+
_unflatten: Callable[[Array], Model]
|
27
27
|
_constraints: Model
|
28
28
|
|
29
|
+
@abstractmethod
|
30
|
+
def filter_spec(self):
|
31
|
+
"""
|
32
|
+
Filter specification for dynamic and static components of the `Variational`.
|
33
|
+
"""
|
34
|
+
pass
|
35
|
+
|
29
36
|
@abstractmethod
|
30
37
|
def sample(self, n: int, key: Key) -> Array:
|
31
38
|
"""
|
@@ -54,13 +61,6 @@ class Variational(eqx.Module):
|
|
54
61
|
"""
|
55
62
|
pass
|
56
63
|
|
57
|
-
@abstractmethod
|
58
|
-
def filter_spec(self):
|
59
|
-
"""
|
60
|
-
Filter specification for dynamic and static components of the `Variational`.
|
61
|
-
"""
|
62
|
-
pass
|
63
|
-
|
64
64
|
@eqx.filter_jit
|
65
65
|
@partial(jax.vmap, in_axes=(None, 0, None))
|
66
66
|
def eval_model(self, draws: Array, data: Any = None) -> Array:
|
@@ -103,7 +103,7 @@ class Variational(eqx.Module):
|
|
103
103
|
- `key`: A PRNG key.
|
104
104
|
"""
|
105
105
|
# Partition variational
|
106
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
106
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
107
107
|
|
108
108
|
# Construct scheduler
|
109
109
|
schedule: Schedule = opx.cosine_decay_schedule(
|
@@ -135,7 +135,7 @@ class Variational(eqx.Module):
|
|
135
135
|
# Update PRNG key
|
136
136
|
key, _ = jr.split(key)
|
137
137
|
|
138
|
-
#
|
138
|
+
# Reconstruct variational
|
139
139
|
vari = eqx.combine(dyn, static)
|
140
140
|
|
141
141
|
# Compute gradient of the ELBO
|
@@ -143,7 +143,7 @@ class Variational(eqx.Module):
|
|
143
143
|
|
144
144
|
# Compute updates
|
145
145
|
updates, opt_state = optim.update(
|
146
|
-
updates, opt_state, eqx.filter(dyn, dyn.filter_spec
|
146
|
+
updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
|
147
147
|
)
|
148
148
|
|
149
149
|
# Update variational distribution
|
@@ -0,0 +1,65 @@
|
|
1
|
+
import jax.lax as lax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
from jax.scipy.special import gammaincc
|
4
|
+
from jaxtyping import Array, ArrayLike, Float
|
5
|
+
|
6
|
+
from bayinx.dists import gamma2
|
7
|
+
|
8
|
+
|
9
|
+
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    nu: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The mixed probability mass/density function (PMF/PDF) for a right-censored (mean-inverse dispersion parameterized) Gamma distribution.

    Values of `x` contribute as follows:
    - Strictly between 0 and `censor`: the Gamma density.
    - Equal to `censor`: the survival probability `P(X > censor)`.
    - Any other value: 0.

    # Parameters
    - `x`: Value(s) at which to evaluate the PMF/PDF.
    - `mu`: The positive mean.
    - `nu`: The positive inverse dispersion.
    - `censor`: The censoring point(s).

    # Returns
    The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
    """
    evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype

    # Construct boolean masks
    uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
    censored: Array = jnp.array(x == censor) # pyright: ignore

    # Substitute a harmless value (1.0) at masked-out positions before
    # evaluating each branch. `jnp.where` discards those entries in the forward
    # pass, but NaN/inf produced there (e.g. log(x) at x <= 0) would otherwise
    # turn the gradients NaN.
    x_uncensored = jnp.where(uncensored, x, 1.0)
    x_censored = jnp.where(censored, x, 1.0)

    # Evaluate mixed probability function
    evals = jnp.where(uncensored, gamma2.prob(x_uncensored, mu, nu), evals)
    # gamma2 uses shape `nu` and rate `nu / mu`, so X * nu / mu ~ Gamma(nu, 1).
    # The regularized upper incomplete gamma therefore gives P(X > censor).
    evals = jnp.where(censored, gammaincc(nu, x_censored * nu / mu), evals) # pyright: ignore

    return evals
|
37
|
+
|
38
|
+
|
39
|
+
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    nu: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored (mean-inverse dispersion parameterized) Gamma distribution.

    Values of `x` contribute as follows:
    - Strictly between 0 and `censor`: the Gamma log density.
    - Equal to `censor`: the log survival probability `log P(X > censor)`.
    - Any other value: -inf.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PMF/PDF.
    - `mu`: The positive mean/location.
    - `nu`: The positive inverse dispersion.
    - `censor`: The censoring point(s).

    # Returns
    The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
    """
    evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype

    # Construct boolean masks
    uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
    censored: Array = jnp.array(x == censor) # pyright: ignore

    # Substitute a harmless value (1.0) at masked-out positions before
    # evaluating each branch. `jnp.where` discards those entries in the forward
    # pass, but NaN/inf produced there (e.g. log(x) at x <= 0, or log(0) from
    # an underflowing survival term) would otherwise turn the gradients NaN.
    x_uncensored = jnp.where(uncensored, x, 1.0)
    x_censored = jnp.where(censored, x, 1.0)

    evals = jnp.where(uncensored, gamma2.logprob(x_uncensored, mu, nu), evals)
    # gamma2 uses shape `nu` and rate `nu / mu`, so X * nu / mu ~ Gamma(nu, 1).
    # The regularized upper incomplete gamma therefore gives P(X > censor).
    evals = jnp.where(censored, lax.log(gammaincc(nu, x_censored * nu / mu)), evals) # pyright: ignore

    return evals
|
bayinx/dists/gamma2.py
CHANGED
@@ -0,0 +1,39 @@
|
|
1
|
+
import jax.lax as lax
|
2
|
+
from jax.scipy.special import gammaln
|
3
|
+
from jaxtyping import Array, ArrayLike, Float
|
4
|
+
|
5
|
+
|
6
|
+
def prob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The probability density function (PDF) for a (mean-precision parameterized) Gamma distribution.

    Obtained by exponentiating `logprob`.

    # Parameters
    - `x`: Value(s) at which to evaluate the PDF.
    - `mu`: The positive mean.
    - `nu`: The positive inverse dispersion.

    # Returns
    The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
    """
    log_density = logprob(x, mu, nu)
    return lax.exp(log_density)
|
22
|
+
|
23
|
+
|
24
|
+
def logprob(
    x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
) -> Float[Array, "..."]:
    """
    The log-transformed probability density function (log PDF) for a (mean-precision parameterized) Gamma distribution.

    Uses shape `nu` and rate `nu / mu`, so the mean is `mu`.

    # Parameters
    - `x`: Value(s) at which to evaluate the log PDF.
    - `mu`: The positive mean/location.
    - `nu`: The positive inverse dispersion.

    # Returns
    The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
    """
    # Terms are accumulated in the same order as the closed-form expression
    # -lgamma(nu) + nu*log(nu/mu) + (nu-1)*log(x) - x*nu/mu.
    log_density = -gammaln(nu) + nu * (lax.log(nu) - lax.log(mu))
    log_density = log_density + (nu - 1.0) * lax.log(x)
    log_density = log_density - (x * nu / mu)
    return log_density # pyright: ignore
|
bayinx/dists/normal.py
CHANGED
@@ -1,31 +1,31 @@
|
|
1
|
-
import jax.lax as
|
2
|
-
from jaxtyping import Array, ArrayLike, Float
|
1
|
+
import jax.lax as lax
|
2
|
+
from jaxtyping import Array, ArrayLike, Float
|
3
3
|
|
4
4
|
__PI = 3.141592653589793
|
5
5
|
|
6
6
|
|
7
7
|
def prob(
|
8
|
-
x:
|
8
|
+
x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
|
9
9
|
) -> Float[Array, "..."]:
|
10
10
|
"""
|
11
11
|
The probability density function (PDF) for a Normal distribution.
|
12
12
|
|
13
13
|
# Parameters
|
14
14
|
- `x`: Value(s) at which to evaluate the PDF.
|
15
|
-
- `mu`: The mean/location
|
16
|
-
- `sigma`: The
|
15
|
+
- `mu`: The mean/location.
|
16
|
+
- `sigma`: The positive standard deviation.
|
17
17
|
|
18
18
|
# Returns
|
19
19
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
20
20
|
"""
|
21
21
|
|
22
|
-
return
|
23
|
-
sigma *
|
22
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / ( # pyright: ignore
|
23
|
+
sigma * lax.sqrt(2.0 * __PI)
|
24
24
|
)
|
25
25
|
|
26
26
|
|
27
27
|
def logprob(
|
28
|
-
x:
|
28
|
+
x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
|
29
29
|
) -> Float[Array, "..."]:
|
30
30
|
"""
|
31
31
|
The log of the probability density function (log PDF) for a Normal distribution.
|
@@ -36,14 +36,16 @@ def logprob(
|
|
36
36
|
- `sigma`: The non-negative standard deviation parameter(s).
|
37
37
|
|
38
38
|
# Returns
|
39
|
-
The log
|
39
|
+
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
40
40
|
"""
|
41
41
|
|
42
|
-
return -
|
42
|
+
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
|
43
|
+
(x - mu) / sigma # pyright: ignore
|
44
|
+
)
|
43
45
|
|
44
46
|
|
45
47
|
def uprob(
|
46
|
-
x:
|
48
|
+
x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
|
47
49
|
) -> Float[Array, "..."]:
|
48
50
|
"""
|
49
51
|
The unnormalized probability density function (uPDF) for a Normal distribution.
|
@@ -51,17 +53,17 @@ def uprob(
|
|
51
53
|
# Parameters
|
52
54
|
- `x`: Value(s) at which to evaluate the uPDF.
|
53
55
|
- `mu`: The mean/location parameter(s).
|
54
|
-
- `sigma`: The
|
56
|
+
- `sigma`: The positive standard deviation parameter(s).
|
55
57
|
|
56
58
|
# Returns
|
57
59
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
58
60
|
"""
|
59
61
|
|
60
|
-
return
|
62
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma # pyright: ignore
|
61
63
|
|
62
64
|
|
63
65
|
def ulogprob(
|
64
|
-
x:
|
66
|
+
x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
|
65
67
|
) -> Float[Array, "..."]:
|
66
68
|
"""
|
67
69
|
The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
|
@@ -75,4 +77,4 @@ def ulogprob(
|
|
75
77
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
76
78
|
"""
|
77
79
|
|
78
|
-
return -
|
80
|
+
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma) # pyright: ignore
|
bayinx/dists/uniform.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
1
|
import jax.lax as _lax
|
2
2
|
import jax.numpy as jnp
|
3
|
-
from jaxtyping import Array, ArrayLike, Float
|
3
|
+
from jaxtyping import Array, ArrayLike, Float
|
4
4
|
|
5
5
|
|
6
6
|
def prob(
|
7
|
-
x:
|
7
|
+
x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
|
8
8
|
) -> Float[Array, "..."]:
|
9
9
|
"""
|
10
10
|
The probability density function (PDF) for a Uniform distribution.
|
@@ -18,11 +18,11 @@ def prob(
|
|
18
18
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
19
19
|
"""
|
20
20
|
|
21
|
-
return 1.0 / (ub - lb)
|
21
|
+
return 1.0 / (ub - lb) # pyright: ignore
|
22
22
|
|
23
23
|
|
24
24
|
def logprob(
|
25
|
-
x:
|
25
|
+
x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
|
26
26
|
) -> Float[Array, "..."]:
|
27
27
|
"""
|
28
28
|
The log of the probability density function (log PDF) for a Uniform distribution.
|
@@ -36,11 +36,11 @@ def logprob(
|
|
36
36
|
The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
37
37
|
"""
|
38
38
|
|
39
|
-
return _lax.log(1.0) - _lax.log(ub - lb)
|
39
|
+
return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
|
40
40
|
|
41
41
|
|
42
42
|
def uprob(
|
43
|
-
x:
|
43
|
+
x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
|
44
44
|
) -> Float[Array, "..."]:
|
45
45
|
"""
|
46
46
|
The unnormalized probability density function (uPDF) for a Uniform distribution.
|
@@ -54,11 +54,11 @@ def uprob(
|
|
54
54
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
55
55
|
"""
|
56
56
|
|
57
|
-
return jnp.ones(jnp.broadcast_arrays(x,lb,ub))
|
57
|
+
return jnp.ones(jnp.broadcast_arrays(x, lb, ub))
|
58
58
|
|
59
59
|
|
60
60
|
def ulogprob(
|
61
|
-
x:
|
61
|
+
x: Float[ArrayLike, "..."], lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."]
|
62
62
|
) -> Float[Array, "..."]:
|
63
63
|
"""
|
64
64
|
The log of the unnormalized probability density function (log uPDF) for a Uniform distribution.
|
@@ -72,4 +72,4 @@ def ulogprob(
|
|
72
72
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
|
73
73
|
"""
|
74
74
|
|
75
|
-
return jnp.zeros(jnp.broadcast_arrays(x,lb,ub))
|
75
|
+
return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
|
@@ -1,29 +1,26 @@
|
|
1
1
|
from functools import partial
|
2
|
-
from typing import
|
2
|
+
from typing import Tuple
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
-
from jaxtyping import Array,
|
7
|
+
from jaxtyping import Array, Scalar
|
8
8
|
|
9
9
|
from bayinx.core import Flow
|
10
10
|
|
11
11
|
|
12
12
|
class FullAffine(Flow):
|
13
13
|
"""
|
14
|
-
|
14
|
+
A full affine flow.
|
15
15
|
|
16
16
|
# Attributes
|
17
17
|
- `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
|
18
18
|
- `constraints`: A dictionary of constraining transformations.
|
19
19
|
"""
|
20
20
|
|
21
|
-
params: Dict[str, Float[Array, "..."]]
|
22
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
23
|
-
|
24
21
|
def __init__(self, dim: int):
|
25
22
|
"""
|
26
|
-
Initializes
|
23
|
+
Initializes a full affine flow.
|
27
24
|
|
28
25
|
# Parameters
|
29
26
|
- `dim`: The dimension of the parameter space.
|
@@ -33,11 +30,23 @@ class FullAffine(Flow):
|
|
33
30
|
"scale": jnp.zeros((dim, dim)),
|
34
31
|
}
|
35
32
|
|
36
|
-
|
33
|
+
if dim == 1:
|
34
|
+
self.constraints = {}
|
35
|
+
else:
|
36
|
+
|
37
|
+
@eqx.filter_jit
|
38
|
+
def constrain_scale(scale: Array):
|
39
|
+
# Extract diagonal and apply exponential
|
40
|
+
diag: Array = jnp.exp(jnp.diag(scale))
|
41
|
+
|
42
|
+
# Return matrix with modified diagonal
|
43
|
+
return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)
|
44
|
+
|
45
|
+
self.constraints = {"scale": constrain_scale}
|
37
46
|
|
38
47
|
@eqx.filter_jit
|
39
48
|
def forward(self, draws: Array) -> Array:
|
40
|
-
params = self.
|
49
|
+
params = self.transform_params()
|
41
50
|
|
42
51
|
# Extract parameters
|
43
52
|
shift: Array = params["shift"]
|
@@ -50,8 +59,8 @@ class FullAffine(Flow):
|
|
50
59
|
|
51
60
|
@eqx.filter_jit
|
52
61
|
@partial(jax.vmap, in_axes=(None, 0))
|
53
|
-
def adjust_density(self, draws: Array) -> Tuple[
|
54
|
-
params = self.
|
62
|
+
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
63
|
+
params = self.transform_params()
|
55
64
|
|
56
65
|
# Extract parameters
|
57
66
|
shift: Array = params["shift"]
|
@@ -60,7 +69,7 @@ class FullAffine(Flow):
|
|
60
69
|
# Compute forward transformation
|
61
70
|
draws = draws @ scale + shift
|
62
71
|
|
63
|
-
# Compute
|
64
|
-
|
72
|
+
# Compute laj
|
73
|
+
laj: Scalar = jnp.log(jnp.diag(scale)).sum()
|
65
74
|
|
66
|
-
return
|
75
|
+
return draws, laj
|
bayinx/mhx/vi/flows/planar.py
CHANGED
@@ -39,7 +39,7 @@ class Planar(Flow):
|
|
39
39
|
@eqx.filter_jit
|
40
40
|
@partial(jax.vmap, in_axes=(None, 0))
|
41
41
|
def forward(self, draws: Array) -> Array:
|
42
|
-
params = self.
|
42
|
+
params = self.transform_params()
|
43
43
|
|
44
44
|
# Extract parameters
|
45
45
|
w: Array = params["w"]
|
@@ -53,8 +53,8 @@ class Planar(Flow):
|
|
53
53
|
|
54
54
|
@eqx.filter_jit
|
55
55
|
@partial(jax.vmap, in_axes=(None, 0))
|
56
|
-
def adjust_density(self, draws: Array) -> Tuple[
|
57
|
-
params = self.
|
56
|
+
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
57
|
+
params = self.transform_params()
|
58
58
|
|
59
59
|
# Extract parameters
|
60
60
|
w: Array = params["w"]
|
@@ -67,8 +67,8 @@ class Planar(Flow):
|
|
67
67
|
# Compute forward transformation
|
68
68
|
draws = draws + u * jnp.tanh(x)
|
69
69
|
|
70
|
-
# Compute
|
70
|
+
# Compute laj
|
71
71
|
h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
|
72
|
-
|
72
|
+
laj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
|
73
73
|
|
74
|
-
return
|
74
|
+
return draws, laj
|
bayinx/mhx/vi/flows/radial.py
CHANGED
@@ -49,7 +49,7 @@ class Radial(Flow):
|
|
49
49
|
# Returns
|
50
50
|
The transformed samples.
|
51
51
|
"""
|
52
|
-
params = self.
|
52
|
+
params = self.transform_params()
|
53
53
|
|
54
54
|
# Extract parameters
|
55
55
|
alpha = params["alpha"]
|
@@ -66,8 +66,8 @@ class Radial(Flow):
|
|
66
66
|
|
67
67
|
@partial(jax.vmap, in_axes=(None, 0))
|
68
68
|
@eqx.filter_jit
|
69
|
-
def adjust_density(self, draws: Array) -> Tuple[
|
70
|
-
params = self.
|
69
|
+
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
70
|
+
params = self.transform_params()
|
71
71
|
|
72
72
|
# Extract parameters
|
73
73
|
alpha = params["alpha"]
|
@@ -84,11 +84,11 @@ class Radial(Flow):
|
|
84
84
|
draws = draws + (x) * (draws - center)
|
85
85
|
|
86
86
|
# Compute density adjustment
|
87
|
-
|
87
|
+
laj = jnp.log(
|
88
88
|
jnp.abs(
|
89
89
|
(1.0 + alpha * beta / (alpha + r) ** 2.0)
|
90
90
|
* (1.0 + x) ** (center.size - 1.0)
|
91
91
|
)
|
92
92
|
)
|
93
93
|
|
94
|
-
return
|
94
|
+
return draws, laj
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Dict, Self
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.numpy as jnp
|
@@ -20,8 +20,6 @@ class MeanField(Variational):
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
var_params: Dict[str, Float[Array, "..."]]
|
23
|
-
_unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
|
24
|
-
_constraints: Model = eqx.field(static=True)
|
25
23
|
|
26
24
|
def __init__(self, model: Model):
|
27
25
|
"""
|
@@ -31,7 +29,7 @@ class MeanField(Variational):
|
|
31
29
|
- `model`: A probabilistic `Model` object.
|
32
30
|
"""
|
33
31
|
# Partition model
|
34
|
-
params, self._constraints = eqx.partition(model, model.filter_spec
|
32
|
+
params, self._constraints = eqx.partition(model, model.filter_spec)
|
35
33
|
|
36
34
|
# Flatten params component
|
37
35
|
params, self._unflatten = ravel_pytree(params)
|
@@ -42,6 +40,22 @@ class MeanField(Variational):
|
|
42
40
|
"log_std": jnp.zeros(params.size, dtype=params.dtype),
|
43
41
|
}
|
44
42
|
|
43
|
+
@property
|
44
|
+
@eqx.filter_jit
|
45
|
+
def filter_spec(self):
|
46
|
+
# Generate empty specification
|
47
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
48
|
+
|
49
|
+
# Specify variational parameters
|
50
|
+
filter_spec = eqx.tree_at(
|
51
|
+
lambda mf: mf.var_params,
|
52
|
+
filter_spec,
|
53
|
+
replace=True,
|
54
|
+
)
|
55
|
+
|
56
|
+
return filter_spec
|
57
|
+
|
58
|
+
|
45
59
|
@eqx.filter_jit
|
46
60
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
47
61
|
# Sample variational draws
|
@@ -61,27 +75,12 @@ class MeanField(Variational):
|
|
61
75
|
sigma=jnp.exp(self.var_params["log_std"]),
|
62
76
|
).sum(axis=1)
|
63
77
|
|
64
|
-
@eqx.filter_jit
|
65
|
-
def filter_spec(self):
|
66
|
-
filter_spec = jtu.tree_map(lambda _: False, self)
|
67
|
-
filter_spec = eqx.tree_at(
|
68
|
-
lambda mf: mf.var_params,
|
69
|
-
filter_spec,
|
70
|
-
replace=True,
|
71
|
-
)
|
72
|
-
return filter_spec
|
73
|
-
|
74
78
|
@eqx.filter_jit
|
75
79
|
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
76
|
-
|
77
|
-
Estimate the ELBO and its gradient(w.r.t the variational parameters).
|
78
|
-
"""
|
79
|
-
# Partition variational
|
80
|
-
dyn, static = eqx.partition(self, self.filter_spec())
|
80
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
81
81
|
|
82
82
|
@eqx.filter_jit
|
83
83
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
|
84
|
-
# Combine
|
85
84
|
vari = eqx.combine(dyn, static)
|
86
85
|
|
87
86
|
# Sample draws from variational distribution
|
@@ -100,8 +99,7 @@ class MeanField(Variational):
|
|
100
99
|
|
101
100
|
@eqx.filter_jit
|
102
101
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
103
|
-
|
104
|
-
dyn, static = eqx.partition(self, self.filter_spec())
|
102
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
105
103
|
|
106
104
|
@eqx.filter_grad
|
107
105
|
@eqx.filter_jit
|
@@ -1,11 +1,11 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Self, Tuple
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.flatten_util as jfu
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import jax.random as jr
|
7
7
|
import jax.tree_util as jtu
|
8
|
-
from jaxtyping import Array,
|
8
|
+
from jaxtyping import Array, Key, Scalar
|
9
9
|
|
10
10
|
from bayinx.core import Flow, Model, Variational
|
11
11
|
|
@@ -17,14 +17,11 @@ class NormalizingFlow(Variational):
|
|
17
17
|
|
18
18
|
# Attributes
|
19
19
|
- `base`: A base variational distribution.
|
20
|
-
- `flows`: An ordered collection of continuously parameterized
|
21
|
-
diffeomorphisms.
|
20
|
+
- `flows`: An ordered collection of continuously parameterized diffeomorphisms.
|
22
21
|
"""
|
23
22
|
|
24
23
|
flows: list[Flow]
|
25
24
|
base: Variational
|
26
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
27
|
-
_constraints: Model
|
28
25
|
|
29
26
|
def __init__(self, base: Variational, flows: list[Flow], model: Model):
|
30
27
|
"""
|
@@ -36,7 +33,7 @@ class NormalizingFlow(Variational):
|
|
36
33
|
- `model`: A probabilistic `Model` object.
|
37
34
|
"""
|
38
35
|
# Partition model
|
39
|
-
params, self._constraints = eqx.partition(model,
|
36
|
+
params, self._constraints = eqx.partition(model, model.filter_spec)
|
40
37
|
|
41
38
|
# Flatten params component
|
42
39
|
_, self._unflatten = jfu.ravel_pytree(params)
|
@@ -44,6 +41,21 @@ class NormalizingFlow(Variational):
|
|
44
41
|
self.base = base
|
45
42
|
self.flows = flows
|
46
43
|
|
44
|
+
@property
|
45
|
+
@eqx.filter_jit
|
46
|
+
def filter_spec(self):
|
47
|
+
# Generate empty specification
|
48
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
49
|
+
|
50
|
+
# Specify variational parameters based on each flow's filter spec.
|
51
|
+
filter_spec = eqx.tree_at(
|
52
|
+
lambda vari: vari.flows,
|
53
|
+
filter_spec,
|
54
|
+
replace=[flow.filter_spec for flow in self.flows],
|
55
|
+
)
|
56
|
+
|
57
|
+
return filter_spec
|
58
|
+
|
47
59
|
@eqx.filter_jit
|
48
60
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)):
|
49
61
|
"""
|
@@ -65,19 +77,18 @@ class NormalizingFlow(Variational):
|
|
65
77
|
|
66
78
|
for map in self.flows:
|
67
79
|
# Compute adjustment
|
68
|
-
|
80
|
+
draws, laj = map.adjust_density(draws)
|
69
81
|
|
70
82
|
# Adjust variational density
|
71
|
-
variational_evals = variational_evals -
|
83
|
+
variational_evals = variational_evals - laj
|
72
84
|
|
73
85
|
return variational_evals
|
74
86
|
|
75
87
|
@eqx.filter_jit
|
76
88
|
def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
|
77
89
|
"""
|
78
|
-
Evaluate the posterior and variational densities at the
|
79
|
-
`draws` to avoid extra compute
|
80
|
-
the posterior evaluation.
|
90
|
+
Evaluate the posterior and variational densities together at the
|
91
|
+
transformed `draws` to avoid extra compute.
|
81
92
|
|
82
93
|
# Parameters
|
83
94
|
- `draws`: Draws from the base variational distribution.
|
@@ -91,32 +102,19 @@ class NormalizingFlow(Variational):
|
|
91
102
|
|
92
103
|
for map in self.flows:
|
93
104
|
# Compute adjustment
|
94
|
-
|
105
|
+
draws, laj = map.adjust_density(draws)
|
95
106
|
|
96
107
|
# Adjust variational density
|
97
|
-
variational_evals = variational_evals -
|
108
|
+
variational_evals = variational_evals - laj
|
98
109
|
|
99
110
|
# Evaluate posterior at final variational draws
|
100
111
|
posterior_evals = self.eval_model(draws, data)
|
101
112
|
|
102
113
|
return posterior_evals, variational_evals
|
103
114
|
|
104
|
-
def filter_spec(self):
|
105
|
-
# Generate empty specification
|
106
|
-
filter_spec = jtu.tree_map(lambda _: False, self)
|
107
|
-
|
108
|
-
# Specify variational parameters based on each flow's filter spec.
|
109
|
-
filter_spec = eqx.tree_at(
|
110
|
-
lambda vari: vari.flows,
|
111
|
-
filter_spec,
|
112
|
-
replace=[flow.filter_spec() for flow in self.flows],
|
113
|
-
)
|
114
|
-
|
115
|
-
return filter_spec
|
116
|
-
|
117
115
|
@eqx.filter_jit
|
118
116
|
def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
|
119
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
117
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
120
118
|
|
121
119
|
@eqx.filter_jit
|
122
120
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None):
|
@@ -133,7 +131,7 @@ class NormalizingFlow(Variational):
|
|
133
131
|
|
134
132
|
@eqx.filter_jit
|
135
133
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
136
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
134
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
137
135
|
|
138
136
|
@eqx.filter_grad
|
139
137
|
@eqx.filter_jit
|
bayinx/mhx/vi/standard.py
CHANGED
@@ -19,7 +19,7 @@ class Standard(Variational):
|
|
19
19
|
- `dim`: Dimension of the parameter space.
|
20
20
|
"""
|
21
21
|
|
22
|
-
dim: int
|
22
|
+
dim: int
|
23
23
|
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
24
|
_constraints: Model
|
25
25
|
|
@@ -31,7 +31,7 @@ class Standard(Variational):
|
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
32
|
"""
|
33
33
|
# Partition model
|
34
|
-
params, self._constraints = eqx.partition(model, model.filter_spec
|
34
|
+
params, self._constraints = eqx.partition(model, model.filter_spec)
|
35
35
|
|
36
36
|
# Flatten params component
|
37
37
|
params, self._unflatten = ravel_pytree(params)
|
@@ -54,7 +54,7 @@ class Standard(Variational):
|
|
54
54
|
sigma=jnp.array(1.0),
|
55
55
|
).sum(axis=1, keepdims=True)
|
56
56
|
|
57
|
-
@
|
57
|
+
@property
|
58
58
|
def filter_spec(self):
|
59
59
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
60
60
|
|
@@ -0,0 +1,30 @@
|
|
1
|
+
bayinx/__init__.py,sha256=htihTsJ54k-ljBLzt4ye8DR7ORwZhxv-bLMcEaDQeqY,86
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
|
4
|
+
bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
|
5
|
+
bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
|
6
|
+
bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
|
7
|
+
bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
|
8
|
+
bayinx/core/model.py,sha256=ADSMapUJGyvKf_TpeC7Foaa3BJ03_Kc7FZxIEKNQkZE,2228
|
9
|
+
bayinx/core/parameter.py,sha256=oxCCZcZ-DDBvfWzexfhQkSJPxNQnE1vYXtBhiEZG2bM,1025
|
10
|
+
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
|
+
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
|
+
bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
14
|
+
bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
|
15
|
+
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
16
|
+
bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
|
18
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
19
|
+
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
20
|
+
bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
|
21
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
|
22
|
+
bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
|
23
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
24
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
25
|
+
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
26
|
+
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
27
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
28
|
+
bayinx-0.3.2.dist-info/METADATA,sha256=9cltWLDiwqg6VpnufQfKYEw_5ZCywJRp7gAPZAogLlA,3057
|
29
|
+
bayinx-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
30
|
+
bayinx-0.3.2.dist-info/RECORD,,
|
bayinx/core/constraints.py
DELETED
@@ -1,61 +0,0 @@
|
|
1
|
-
from abc import abstractmethod
|
2
|
-
from typing import Tuple
|
3
|
-
|
4
|
-
import equinox as eqx
|
5
|
-
import jax.numpy as jnp
|
6
|
-
from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
|
7
|
-
|
8
|
-
|
9
|
-
class Constraint(eqx.Module):
|
10
|
-
"""
|
11
|
-
Abstract base class for defining parameter constraints.
|
12
|
-
|
13
|
-
Subclasses should implement the `constrain` method to apply the
|
14
|
-
transformation and compute the ladj adjustment.
|
15
|
-
"""
|
16
|
-
@abstractmethod
|
17
|
-
def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
|
18
|
-
"""
|
19
|
-
Applies the constraining transformation to an unconstrained input
|
20
|
-
and computes the log absolute determinant of the Jacobian (ladj)
|
21
|
-
of this transformation.
|
22
|
-
|
23
|
-
# Parameters
|
24
|
-
- `x`: The unconstrained JAX Array-like input.
|
25
|
-
|
26
|
-
# Returns
|
27
|
-
A tuple containing:
|
28
|
-
- The constrained JAX Array.
|
29
|
-
- A scalar JAX Array representing the ladj of the transformation.
|
30
|
-
"""
|
31
|
-
pass
|
32
|
-
|
33
|
-
|
34
|
-
class LowerBound(Constraint):
|
35
|
-
"""
|
36
|
-
Enforces a lower bound on the parameter.
|
37
|
-
"""
|
38
|
-
lb: ScalarLike
|
39
|
-
|
40
|
-
def __init__(self, lb: ScalarLike):
|
41
|
-
self.lb = lb
|
42
|
-
|
43
|
-
def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
|
44
|
-
"""
|
45
|
-
Applies the lower bound constraint and computes the ladj.
|
46
|
-
|
47
|
-
# Parameters
|
48
|
-
- `x`: The unconstrained JAX Array-like input.
|
49
|
-
|
50
|
-
# Parameters
|
51
|
-
A tuple containing:
|
52
|
-
- The constrained JAX Array (x > self.lb).
|
53
|
-
- A scalar JAX Array representing the ladj of the transformation.
|
54
|
-
"""
|
55
|
-
# Compute transformation adjustment
|
56
|
-
ladj = jnp.sum(x)
|
57
|
-
|
58
|
-
# Compute transformation
|
59
|
-
x = jnp.exp(x) + self.lb
|
60
|
-
|
61
|
-
return x, ladj
|
bayinx/core/utils.py
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
|
bayinx/dists/gamma.py
DELETED
File without changes
|
bayinx-0.2.25.dist-info/RECORD
DELETED
@@ -1,28 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
-
bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
|
5
|
-
bayinx/core/flow.py,sha256=9swS5wh7AsIZWgG_IhQS-upcPlr7G-juaP_5rsbX6G0,2363
|
6
|
-
bayinx/core/model.py,sha256=U1xBnAXnIvFJjWF-XIT8BYjP9PtoRZY_PwyhRwJf-HA,2144
|
7
|
-
bayinx/core/utils.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
8
|
-
bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
|
9
|
-
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
11
|
-
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
bayinx/dists/normal.py,sha256=rtSDi0NAObH1LGRWiPZk_6cbSVv2dOPHkgxtWn6gFgM,2662
|
15
|
-
bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
|
16
|
-
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
17
|
-
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
18
|
-
bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
|
19
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
|
20
|
-
bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
|
21
|
-
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
22
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
|
23
|
-
bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
|
24
|
-
bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
|
25
|
-
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
26
|
-
bayinx-0.2.25.dist-info/METADATA,sha256=9TI5NY4M1EBtYwA8E-EDds-QkOECwrfVaaEWq3pTdu4,3058
|
27
|
-
bayinx-0.2.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
28
|
-
bayinx-0.2.25.dist-info/RECORD,,
|
File without changes
|
File without changes
|