bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +4 -3
- bayinx/constraints/identity.py +26 -0
- bayinx/constraints/interval.py +62 -0
- bayinx/constraints/lower.py +31 -24
- bayinx/constraints/upper.py +57 -0
- bayinx/core/__init__.py +0 -7
- bayinx/core/constraint.py +32 -0
- bayinx/core/context.py +42 -0
- bayinx/core/distribution.py +34 -0
- bayinx/core/flow.py +99 -0
- bayinx/core/model.py +228 -0
- bayinx/core/node.py +201 -0
- bayinx/core/types.py +17 -0
- bayinx/core/utils.py +109 -0
- bayinx/core/variational.py +170 -0
- bayinx/dists/__init__.py +5 -3
- bayinx/dists/bernoulli.py +180 -11
- bayinx/dists/binomial.py +215 -0
- bayinx/dists/exponential.py +211 -0
- bayinx/dists/normal.py +131 -59
- bayinx/dists/poisson.py +203 -0
- bayinx/flows/__init__.py +5 -0
- bayinx/flows/diagaffine.py +120 -0
- bayinx/flows/fullaffine.py +123 -0
- bayinx/flows/lowrankaffine.py +165 -0
- bayinx/flows/planar.py +155 -0
- bayinx/flows/radial.py +1 -0
- bayinx/flows/sylvester.py +225 -0
- bayinx/nodes/__init__.py +3 -0
- bayinx/nodes/continuous.py +64 -0
- bayinx/nodes/observed.py +36 -0
- bayinx/nodes/stochastic.py +25 -0
- bayinx/ops.py +104 -0
- bayinx/posterior.py +220 -0
- bayinx/vi/__init__.py +0 -0
- bayinx/{mhx/vi → vi}/meanfield.py +33 -29
- bayinx/vi/normalizing_flow.py +246 -0
- bayinx/vi/standard.py +95 -0
- bayinx-0.5.3.dist-info/METADATA +93 -0
- bayinx-0.5.3.dist-info/RECORD +44 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
- bayinx/core/_constraint.py +0 -28
- bayinx/core/_flow.py +0 -80
- bayinx/core/_model.py +0 -98
- bayinx/core/_parameter.py +0 -44
- bayinx/core/_variational.py +0 -181
- bayinx/dists/censored/__init__.py +0 -3
- bayinx/dists/censored/gamma2/__init__.py +0 -3
- bayinx/dists/censored/gamma2/r.py +0 -68
- bayinx/dists/censored/posnormal/__init__.py +0 -3
- bayinx/dists/censored/posnormal/r.py +0 -116
- bayinx/dists/gamma2.py +0 -49
- bayinx/dists/posnormal.py +0 -260
- bayinx/dists/uniform.py +0 -75
- bayinx/mhx/__init__.py +0 -1
- bayinx/mhx/vi/__init__.py +0 -5
- bayinx/mhx/vi/flows/__init__.py +0 -3
- bayinx/mhx/vi/flows/fullaffine.py +0 -75
- bayinx/mhx/vi/flows/planar.py +0 -74
- bayinx/mhx/vi/flows/radial.py +0 -94
- bayinx/mhx/vi/flows/sylvester.py +0 -19
- bayinx/mhx/vi/normalizing_flow.py +0 -149
- bayinx/mhx/vi/standard.py +0 -63
- bayinx-0.3.10.dist-info/METADATA +0 -39
- bayinx-0.3.10.dist-info/RECORD +0 -35
- /bayinx/{py.typed → flows/otflow.py} +0 -0
- {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-from
-
-
+from .core.model import Model as Model
+from .core.model import define as define
+from .posterior import Posterior as Posterior
bayinx/constraints/__init__.py
CHANGED
@@ -1,3 +1,4 @@
-from
-
-
+from .identity import Identity as Identity
+from .interval import Interval as Interval
+from .lower import Lower as Lower
+from .upper import Upper as Upper
bayinx/constraints/identity.py
ADDED
@@ -0,0 +1,26 @@
+from typing import Tuple
+
+import jax.numpy as jnp
+from jaxtyping import PyTree, Scalar
+
+from bayinx.core.constraint import Constraint
+from bayinx.core.types import T
+
+
+class Identity(Constraint):
+    """
+    Does nothing.
+    """
+
+    def constrain(self, obj: T, filter_spec: PyTree) -> Tuple[T, Scalar]:
+        """
+        Applies the identity transformation (does nothing) and computes its log-Jacobian adjustment (0).
+
+        # Returns
+        A tuple containing:
+        - The same `PyTree`.
+        - A scalar `Array` containing 0.
+        """
+        log_jac: Scalar = jnp.array(0.0)
+
+        return obj, log_jac
bayinx/constraints/interval.py
ADDED
@@ -0,0 +1,62 @@
+from typing import Any, Tuple
+
+import jax.numpy as jnp
+import jax.tree as jt
+from jaxtyping import Scalar, ScalarLike
+
+from bayinx.core.constraint import Constraint
+from bayinx.core.types import T
+
+
+class Interval(Constraint):
+    """
+    Enforces that the parameter lies in the (lb, ub) interval using a scaled
+    and shifted sigmoid transformation.
+
+    # Attributes
+    - `lb`: The lower bound.
+    - `ub`: The upper bound.
+    """
+
+    lb: Scalar
+    ub: Scalar
+
+    def __init__(self, lb: ScalarLike, ub: ScalarLike):
+        self.lb = jnp.asarray(lb)
+        self.ub = jnp.asarray(ub)
+
+    def constrain(self, obj: T, filter_spec: T) -> Tuple[T, Scalar]:
+        """
+        Applies the scaled sigmoid transformation to the leaves of a `PyTree` and
+        computes the log-Jacobian adjustment.
+
+        # Parameters
+        - `obj`: The unconstrained `PyTree` (values are in R).
+
+        # Returns
+        A tuple containing:
+        - A `PyTree` with its values `y` now satisfying lb < y < ub.
+        - A scalar `Array` containing the log-absolute-Jacobian of the
+          transformation.
+        """
+        log_jac: Scalar = jnp.array(0.0)
+
+        def constrain_leaf(leaf: Any, filter: bool):
+            nonlocal log_jac  # Reference outer variable
+
+            if filter:
+                # Apply transformation
+                constrained = self.lb + (self.ub - self.lb) * jnp.exp(leaf) / (1.0 + jnp.exp(leaf))
+
+                log_jac = log_jac + (jnp.log(constrained - self.lb) +
+                                     jnp.log(self.ub - constrained) -
+                                     jnp.log(self.ub - self.lb)).sum()
+
+                return constrained
+            else:
+                return leaf
+
+        # Constrain leaves
+        obj = jt.map(constrain_leaf, obj, filter_spec)
+
+        return obj, log_jac
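For reference, the adjustment accumulated in `constrain_leaf` above is the standard one for a scaled, shifted sigmoid: with y = lb + (ub - lb) * sigmoid(x), log|dy/dx| = log(y - lb) + log(ub - y) - log(ub - lb). A minimal self-contained check against JAX autodiff (an illustration, not part of the package):

    import jax
    import jax.numpy as jnp

    lb, ub = 0.0, 2.0
    f = lambda x: lb + (ub - lb) * jax.nn.sigmoid(x)

    x = jnp.array(0.3)
    y = f(x)

    # Adjustment as accumulated in Interval.constrain
    manual = jnp.log(y - lb) + jnp.log(ub - y) - jnp.log(ub - lb)

    # Autodiff ground truth
    auto = jnp.log(jnp.abs(jax.grad(f)(x)))

    assert jnp.allclose(manual, auto)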
bayinx/constraints/lower.py
CHANGED
@@ -1,50 +1,57 @@
-from typing import Tuple
+from typing import Any, Tuple
 
-import equinox as eqx
 import jax.numpy as jnp
 import jax.tree as jt
-from jaxtyping import
+from jaxtyping import Scalar, ScalarLike
 
-from bayinx.core import Constraint
+from bayinx.core.constraint import Constraint
+from bayinx.core.types import T
 
 
 class Lower(Constraint):
     """
     Enforces a lower bound on the parameter.
+
+    # Attributes
+    - `lb`: The lower bound.
     """
 
     lb: Scalar
 
     def __init__(self, lb: ScalarLike):
-        self.lb = jnp.
+        self.lb = jnp.asarray(lb)
 
-
-    def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
+    def constrain(self, obj: T, filter_spec: T) -> Tuple[T, Scalar]:
         """
-
+        Applies the exponential transformation to the leaves of a `PyTree` and
+        computes the log-Jacobian adjustment of the transformation.
 
         # Parameters
-        - `x`: The unconstrained `
+        - `x`: The unconstrained `PyTree`.
 
-        #
+        # Returns
         A tuple containing:
-        - A
-        - A scalar Array
+        - A `PyTree` with its values `x` now satisfying `lb <= x`.
+        - A scalar `Array` containing the log-absolute-Jacobian of the
+          transformation.
         """
-
-
+        log_jac: Scalar = jnp.array(0.0)
+
+        def constrain_leaf(leaf: Any, filter: bool):
+            nonlocal log_jac  # Reference outer variable
 
-
-
+            if filter:
+                # Apply transformation
+                constrained = jnp.exp(leaf) + self.lb
 
-
-
-        laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
+                # Accumulate Jacobian adjustment
+                log_jac = log_jac + jnp.sum(leaf)
 
-
-
+                return constrained
+            else:
+                return leaf
 
-        #
-
+        # Constrain leaves
+        obj = jt.map(constrain_leaf, obj, filter_spec)
 
-        return
+        return obj, log_jac
bayinx/constraints/upper.py
ADDED
@@ -0,0 +1,57 @@
+from typing import Any, Tuple
+
+import jax.numpy as jnp
+import jax.tree as jt
+from jaxtyping import Scalar, ScalarLike
+
+from bayinx.core.constraint import Constraint
+from bayinx.core.types import T
+
+
+class Upper(Constraint):
+    """
+    Enforces an upper bound on the parameter.
+
+    # Attributes
+    - `ub`: The upper bound.
+    """
+
+    ub: Scalar
+
+    def __init__(self, ub: ScalarLike):
+        self.ub = jnp.asarray(ub)
+
+    def constrain(self, obj: T, filter_spec: T) -> Tuple[T, Scalar]:
+        """
+        Applies the negated exponential transformation to the leaves of a `PyTree` and
+        computes the log-Jacobian adjustment of the transformation.
+
+        # Parameters
+        - `x`: The unconstrained `PyTree`.
+
+        # Returns
+        A tuple containing:
+        - A `PyTree` with its values `x` now satisfying `x <= ub`.
+        - A scalar `Array` containing the log-absolute-Jacobian of the
+          transformation.
+        """
+        log_jac: Scalar = jnp.array(0.0)
+
+        def constrain_leaf(leaf: Any, include: bool):
+            nonlocal log_jac  # Reference outer variable
+
+            if include:
+                # Apply transformation
+                constrained = -jnp.exp(leaf) + self.ub
+
+                # Accumulate Jacobian adjustment
+                log_jac = log_jac + jnp.sum(leaf)
+
+                return constrained
+            else:
+                return leaf
+
+        # Constrain leaves
+        obj = jt.map(constrain_leaf, obj, filter_spec)
+
+        return obj, log_jac
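Both `Lower` and `Upper` use the same adjustment: for y = exp(x) + lb or y = -exp(x) + ub, |dy/dx| = exp(x), so the log-absolute-Jacobian is just the unconstrained value itself, summed over the constrained entries; hence `log_jac + jnp.sum(leaf)` in both classes. A quick autodiff check (illustration only, not part of the package):

    import jax
    import jax.numpy as jnp

    x = jnp.array(-0.7)

    for f in (lambda x: jnp.exp(x) + 1.0,    # Lower(1.0)
              lambda x: -jnp.exp(x) + 1.0):  # Upper(1.0)
        assert jnp.allclose(jnp.log(jnp.abs(jax.grad(f)(x))), x)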
bayinx/core/__init__.py
CHANGED

bayinx/core/constraint.py
ADDED
@@ -0,0 +1,32 @@
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Tuple
+
+import equinox as eqx
+from jaxtyping import PyTree, Scalar
+
+if TYPE_CHECKING:
+    from bayinx.core.types import T
+
+
+class Constraint(eqx.Module):
+    """
+    Abstract base class for defining constraints (for stochastic nodes).
+    """
+
+    @abstractmethod
+    def constrain(self, obj: "T", filter_spec: PyTree) -> Tuple["T", Scalar]:
+        """
+        Applies the constraining transformation to the leaves of a `PyTree` and
+        computes the log-Jacobian adjustment of the transformation.
+
+        # Parameters
+        - `x`: The unconstrained values.
+        - `filter_spec`: The filter specification for `values`.
+
+        # Returns
+        A tuple containing:
+        - A `PyTree` with its leaves now constrained.
+        - A scalar `Array` containing the log-Jacobian adjustment of the
+          transformation.
+        """
+        pass
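A new constraint only needs to subclass this ABC and implement `constrain`. As a sketch (not part of the package), a positivity constraint using softplus instead of the exponential used by `Lower` above:

    from typing import Any, Tuple

    import jax.numpy as jnp
    import jax.tree as jt
    from jaxtyping import PyTree, Scalar

    from bayinx.core.constraint import Constraint


    class Softplus(Constraint):
        def constrain(self, obj: Any, filter_spec: PyTree) -> Tuple[Any, Scalar]:
            log_jac: Scalar = jnp.array(0.0)

            def constrain_leaf(leaf: Any, filter: bool):
                nonlocal log_jac

                if filter:
                    # y = log(1 + exp(x)); dy/dx = sigmoid(x),
                    # so log|dy/dx| = -softplus(-x)
                    log_jac = log_jac + jnp.sum(-jnp.logaddexp(0.0, -leaf))
                    return jnp.logaddexp(0.0, leaf)
                else:
                    return leaf

            obj = jt.map(constrain_leaf, obj, filter_spec)

            return obj, log_jac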
bayinx/core/context.py
ADDED
@@ -0,0 +1,42 @@
+import threading
+from contextlib import contextmanager
+from dataclasses import dataclass
+
+import jax.numpy as jnp
+from jaxtyping import Scalar
+
+# Local storage for model context
+_model_context = threading.local()
+
+
+@dataclass
+class Target:
+    value: Scalar
+
+    # Define relevant methods to get around lack of explicit mutability in JAX
+    def __iadd__(self, other):
+        self.value = self.value + other
+        return self
+
+    def __add__(self, other):
+        return self.value + other
+
+    def __radd__(self, other):
+        return self.value + other
+
+
+@contextmanager
+def model_context():
+    """
+    Context manager that sets up the implicit `target` accumulator.
+
+    This context allows the `<<` operator for Node classes to automatically
+    accumulate log probabilities without requiring explicit handling of target.
+    """
+    _model_context.target = Target(jnp.array(0.0))
+
+    try:
+        yield _model_context.target
+    finally:  # Remove old context if present
+        if hasattr(_model_context, "target"):
+            delattr(_model_context, "target")
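In isolation the accumulator behaves like a mutable scalar (a sketch; user code normally never touches this directly, since `Model.__call__` further down in this diff enters the context itself):

    import jax.numpy as jnp

    from bayinx.core.context import model_context

    with model_context() as target:
        target += jnp.array(-1.5)   # e.g. a log-density increment
        target += jnp.array(-0.25)

    # The thread-local slot is cleared on exit, but the accumulated
    # value remains on the yielded object
    assert jnp.allclose(target.value, -1.75)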
bayinx/core/distribution.py
ADDED
@@ -0,0 +1,34 @@
+from typing import Protocol, Tuple
+
+import jax.random as jr
+from jaxtyping import PRNGKeyArray, Scalar
+
+from bayinx.core.node import Node
+
+
+class Distribution(Protocol):
+    """
+    A protocol used for defining the structure of distributions.
+    """
+
+    def logprob(self, node: Node) -> Scalar: ...
+
+    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)): ...
+
+    def __rlshift__(self, node: Node):
+        """
+        Implicitly accumulate the log probability into the current model context.
+        """
+        from bayinx.core.context import _model_context
+
+        # Evaluate log posterior
+        log_prob = self.logprob(node)
+
+        # Accumulate log probability into context
+        if hasattr(_model_context, "target"):
+            _model_context.target += log_prob
+        else:
+            raise RuntimeError(
+                "Model context doesn't exist. Make sure you're calling "
+                + "this within the 'model' method."
+            )
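Since `__rlshift__` carries a default implementation, a concrete distribution that explicitly subclasses the protocol gets the `<<` statement for free. A toy sketch (the `node.obj` attribute mirrors its use in `Model.constrain` later in this diff; everything else is illustrative and not the package's `dists` API):

    import jax.numpy as jnp
    import jax.random as jr
    from jaxtyping import Scalar

    from bayinx.core.distribution import Distribution


    class ToyNormal(Distribution):
        def __init__(self, mu, sigma):
            self.mu, self.sigma = mu, sigma

        def logprob(self, node) -> Scalar:
            x = node.obj  # the node's current value (assumed attribute)
            return jnp.sum(
                -0.5 * ((x - self.mu) / self.sigma) ** 2
                - jnp.log(self.sigma)
                - 0.5 * jnp.log(2.0 * jnp.pi)
            )

        def sample(self, shape, key=jr.key(0)):
            return self.mu + self.sigma * jr.normal(key, shape)

Inside a model context, `node << ToyNormal(0.0, 1.0)` falls back to `ToyNormal.__rlshift__(node)`, which adds `logprob(node)` to the thread-local target.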
bayinx/core/flow.py
ADDED
@@ -0,0 +1,99 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, Self, Tuple
+
+import equinox as eqx
+import jax.tree_util as jtu
+from jaxtyping import Array, PyTree
+
+
+class FlowLayer(eqx.Module):
+    """
+    An abstract base class for a flow layer.
+
+    # Attributes
+    - `params`: The parameters of the diffeomorphism. # TODO FOR ALL FLOWS
+    - `constraints`: The constraining transformations for parameters.
+    - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
+    """
+
+    params: Dict[str, PyTree]
+    constraints: Dict[str, Callable[[PyTree], Array]]
+    static: bool
+
+    @abstractmethod
+    def forward(self, draws: Array) -> Array:
+        """
+        Computes the forward transformation of `draws`.
+        """
+        pass
+
+    @abstractmethod
+    def adjust(self, draws: Array) -> Array:
+        """
+        Computes the log-Jacobian adjustment for each draw in `draws`.
+
+        # Returns
+        An array of the log-Jacobian adjustments.
+        """
+        pass
+
+    @abstractmethod
+    def forward_and_adjust(self, draws: Array) -> Tuple[Array, Array]:
+        """
+        Computes the log-Jacobian adjustment at `draws` and applies the forward transformation.
+
+        # Returns
+        A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
+        """
+        pass
+
+    # Default filter specification
+    @property
+    def filter_spec(self) -> "FlowLayer":
+        """
+        Generates a filter specification to subset relevant parameters for the flow.
+        """
+        # Generate empty specification
+        filter_spec = jtu.tree_map(lambda _: False, self)
+
+        if self.static is False:
+            # Specify parameters
+            filter_spec = eqx.tree_at(
+                lambda flow: flow.params,
+                filter_spec,
+                replace=True,
+            )
+
+        return filter_spec
+
+    def constrain_params(self: Self) -> Dict[str, PyTree]:
+        """
+        Constrain flow parameters to the appropriate domain.
+
+        # Returns
+        The constrained parameters of the diffeomorphism.
+        """
+        params = self.params
+
+        for par, map in self.constraints.items():
+            params[par] = map(params[par])
+
+        return params
+
+    def transform_params(self: Self) -> Dict[str, PyTree]:
+        """
+        Apply a custom transformation to `params` if needed. Defaults to `constrain_params()`.
+
+        # Returns
+        The transformed parameters of the diffeomorphism.
+        """
+        return self.constrain_params()
+
+
+class FlowSpec(ABC):
+    """
+    A specification for a flow layer.
+    """
+
+    @abstractmethod
+    def construct(self, dim: int) -> FlowLayer: ...
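A minimal concrete layer against this interface, for illustration only (a diagonal affine map; the `bayinx/flows/diagaffine.py` added in this release may differ, as its content is not shown here). `draws` is assumed to have shape `(n_draws, dim)`:

    from typing import Tuple

    import jax.numpy as jnp
    from jaxtyping import Array

    from bayinx.core.flow import FlowLayer, FlowSpec


    class DiagAffineLayer(FlowLayer):
        def __init__(self, dim: int):
            # Raw (unconstrained) parameters; scales are exponentiated on use
            self.params = {"shift": jnp.zeros(dim), "log_scale": jnp.zeros(dim)}
            self.constraints = {}
            self.static = False

        def forward(self, draws: Array) -> Array:
            return jnp.exp(self.params["log_scale"]) * draws + self.params["shift"]

        def adjust(self, draws: Array) -> Array:
            # log|det J| of an elementwise affine map is the sum of
            # log-scales, identical for every draw
            return jnp.full(draws.shape[0], self.params["log_scale"].sum())

        def forward_and_adjust(self, draws: Array) -> Tuple[Array, Array]:
            return self.forward(draws), self.adjust(draws)


    class DiagAffine(FlowSpec):
        def construct(self, dim: int) -> FlowLayer:
            return DiagAffineLayer(dim)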
bayinx/core/model.py
ADDED
@@ -0,0 +1,228 @@
+from abc import abstractmethod
+from dataclasses import field, fields
+from typing import (
+    Dict,
+    Optional,
+    Self,
+    Tuple,
+    Type,
+    get_origin,
+    get_type_hints,
+)
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.tree as jt
+from jaxtyping import PyTree, Scalar
+
+from bayinx.constraints import Identity, Interval, Lower, Upper
+from bayinx.core.context import _model_context, model_context
+from bayinx.core.node import Node
+from bayinx.core.types import HasConstraint
+from bayinx.core.utils import _extract_shape_params, _resolve_shape_spec
+from bayinx.nodes import Continuous, Observed, Stochastic
+
+
+def define(
+    shape: Optional[int | str | Tuple[int | str, ...]] = None,
+    init: Optional[PyTree] = None,
+    lower: Optional[float] = None,
+    upper: Optional[float] = None
+):
+    """
+    Define a stochastic node.
+    """
+    metadata: Dict = {}
+
+    if shape is not None:
+        metadata["shape"] = shape
+    if init is not None:
+        metadata["init"] = init
+
+    match (lower, upper):
+        case (float() | int(), None):
+            metadata["constraint"] = Lower(lower)  # type: ignore
+        case (None, float() | int()):
+            metadata["constraint"] = Upper(upper)  # type: ignore
+        case (float() | int(), float() | int()):
+            metadata["constraint"] = Interval(lower, upper)  # type: ignore
+        case (None, None):
+            metadata["constraint"] = Identity()
+        case (_):
+            raise TypeError("TODO.")
+
+    return field(metadata=metadata)
+
+
+class Model(eqx.Module):
+    """
+    A base class used to define probabilistic models.
+    """
+
+    def __init_subclass__(cls, **kwargs):
+        # Consume 'init' argument before passing it up to Equinox
+        kwargs.pop('init', None)
+        super().__init_subclass__(**kwargs)
+
+    def __init__(self, **kwargs):
+        cls = self.__class__
+
+        # Grab initialized parameters
+        init_params: set[str] = {f.name for f in fields(cls) if f.name in kwargs.keys()}
+
+        # Grab shape parameters from model definition
+        shape_params: set[str] = set()
+        for node_defn in fields(cls):
+            if (shape_spec := node_defn.metadata.get("shape")) is not None:
+                shape_params = shape_params | _extract_shape_params(shape_spec)
+
+        # Check all shape parameters are passed as arguments
+        if not shape_params.issubset(kwargs.keys()):
+            missing_params = shape_params - kwargs.keys()
+            raise TypeError(
+                f"Following shape parameters were not specified during model initialization: {", ".join(missing_params)}"
+            )
+
+        # Define all initialized dimensions
+        shape_values: dict = {
+            shape_param: kwargs[shape_param]
+            for shape_param in shape_params
+        }
+
+        # Grab node types
+        node_types: dict[str, Type[Node]] = {k: get_origin(v) for k, v in get_type_hints(cls).items()}
+
+
+        # Auto-initialize parameters based on field metadata and type annotations
+        for node_defn in fields(cls):
+            # Grab node type
+            node_type = node_types[node_defn.name]
+
+            # Grab shape information if available
+            shape_spec: str | None = node_defn.metadata.get("shape")
+            shape = _resolve_shape_spec(shape_spec, shape_values)
+
+            # Construct object
+            if node_defn.name in init_params:  # Initialized in model construction
+                obj = kwargs[node_defn.name]
+            elif "init" in node_defn.metadata:  # Initialized in model definition
+                obj = node_defn.metadata["init"]
+            elif shape is not None:  # Shape defined in model definition
+                obj = jnp.zeros(shape)  # TODO: will change later for discrete objects
+            else:
+                raise ValueError(f"Node '{node_defn.name}' not initialized or defined.")
+
+            # Check shape
+            if shape is not None and jnp.shape(obj) != shape:
+                raise ValueError(f"Expected shape {shape} for {node_defn.name} but got {jnp.shape(obj)}.")
+
+            if issubclass(node_type, Stochastic):
+                if node_type == Continuous:
+                    setattr(
+                        self,
+                        node_defn.name,
+                        Continuous(obj, node_defn.metadata["constraint"]),
+                    )
+                else:
+                    TypeError(f"{node_type.__name__} is not implemented yet")
+            elif issubclass(node_type, Observed):
+                setattr(self, node_defn.name, Observed(obj))
+            else:
+                raise TypeError(f"{node_defn.name} node is neither Stochastic nor Observed but {node_type.__name__}.")
+
+
+    @property
+    def filter_spec(self) -> Self:
+        """
+        Generates a filter specification to subset stochastic elements of the model.
+        """
+        # Generate empty specification
+        filter_spec: Self = jt.map(lambda _: False, self)
+
+        for f in fields(self):  # type: ignore
+            # Extract attribute
+            node: Node = getattr(self, f.name)
+
+            # Check if attribute is stochastic
+            if isinstance(node, Stochastic):
+                # Update model's filter specification at node
+                filter_spec: Self = eqx.tree_at(
+                    lambda model: getattr(model, f.name),
+                    filter_spec,
+                    replace=node.filter_spec
+                )
+
+        return filter_spec
+
+    def filter_for(self, node_type: Type[Stochastic]) -> Self:
+        """
+        Generates a filter specification to subset stochastic elements of a certain type of the model.
+        """
+        # Generate empty specification
+        filter_spec: Self = jt.map(lambda _: False, self)
+
+        for f in fields(self):  # type: ignore
+            # Extract node
+            node: Node = getattr(self, f.name)
+
+            if isinstance(node, node_type):
+                # Update model's filter specification for node
+                filter_spec: Self = eqx.tree_at(
+                    lambda model: getattr(model, f.name),
+                    filter_spec,
+                    replace=node.filter_spec,
+                )
+
+        return filter_spec
+
+
+    def constrain(self, jacobian: bool = True) -> Tuple[Self, Scalar]:
+        """
+        Constrain nodes to the appropriate domain.
+
+        # Returns
+        A tuple containing the constrained `Model` object and the log-Jacobian adjustment.
+        """
+        model: Self = self
+        target: Scalar = jnp.array(0.0)
+
+        for f in fields(self):  # type: ignore
+            # Extract attribute
+            node = getattr(self, f.name)
+
+            # Check if node has a constraint
+            if isinstance(node, HasConstraint):
+                # Apply constraint
+                obj, laj = node._constraint.constrain(node.obj, node._filter_spec)
+
+                # Update values with constrained counterpart
+                model = eqx.tree_at(
+                    where=lambda model: getattr(model, f.name).obj,
+                    pytree=model,
+                    replace=obj,
+                )
+
+                # Adjust posterior density
+                if jacobian:
+                    target += laj
+
+        return model, target
+
+    @abstractmethod
+    def model(self, target: Scalar) -> Scalar:
+        pass
+
+    @eqx.filter_jit
+    def __call__(self) -> Scalar:
+        with model_context():
+            # Constrain the model and accumulate Jacobian adjustments
+            self, target = self.constrain()
+
+            # Accumulate manual increments
+            target += self.model(jnp.array(0.0))
+
+            # Accumulate implicit increments
+            target += _model_context.target.value
+
+            # Return the accumulated target
+            return target
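Putting the pieces together, a model definition under this class might look as follows. This is a hypothetical sketch: the subscripted `Continuous[...]`/`Observed[...]` annotations follow from the `get_origin` call above, while the `Normal` import and its constructor signature are assumptions about `bayinx.dists`, whose changes are not shown in this diff.

    import jax.numpy as jnp
    from jaxtyping import Array, Scalar

    from bayinx import Model, define
    from bayinx.dists import Normal            # assumed import path and API
    from bayinx.nodes import Continuous, Observed


    class Location(Model):
        # Stochastic nodes: an unconstrained mean and a deviation kept
        # positive via Lower(0.0)
        mu: Continuous[Array] = define(shape=1)
        sigma: Continuous[Array] = define(shape=1, lower=0.0)

        # Observed node, supplied at construction
        y: Observed[Array]

        def model(self, target: Scalar) -> Scalar:
            # Implicit increments via << (Distribution.__rlshift__)
            self.mu << Normal(0.0, 1.0)
            self.y << Normal(self.mu, self.sigma)

            return target

Evaluating `Location(y=jnp.zeros(10))()` would then return the accumulated log density, with the Jacobian adjustment for `sigma` added by `constrain` before `model` runs.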