bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from bayinx.core import Model, Parameter, constrain
2
-
3
- __all__ = ["Model", "Parameter", "constrain"]
1
+ from .core.model import Model as Model
2
+ from .core.model import define as define
3
+ from .posterior import Posterior as Posterior
bayinx/constraints/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
- from bayinx.constraints.lower import Lower
2
-
3
- __all__ = ['Lower']
1
+ from .identity import Identity as Identity
2
+ from .interval import Interval as Interval
3
+ from .lower import Lower as Lower
4
+ from .upper import Upper as Upper
bayinx/constraints/identity.py ADDED
@@ -0,0 +1,26 @@
1
+ from typing import Tuple
2
+
3
+ import jax.numpy as jnp
4
+ from jaxtyping import PyTree, Scalar
5
+
6
+ from bayinx.core.constraint import Constraint
7
+ from bayinx.core.types import T
8
+
9
+
10
class Identity(Constraint):
    """
    A no-op constraint: every leaf is left exactly as it is.
    """

    def constrain(self, obj: T, filter_spec: PyTree) -> Tuple[T, Scalar]:
        """
        Return `obj` unchanged together with a zero log-Jacobian adjustment.

        # Parameters
        - `obj`: The `PyTree` to (not) transform.
        - `filter_spec`: Unused; accepted so the signature matches `Constraint`.

        # Returns
        A tuple containing:
        - `obj` itself, untouched.
        - A scalar `Array` equal to 0.
        """
        return obj, jnp.array(0.0)
bayinx/constraints/interval.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing import Any, Tuple
2
+
3
+ import jax.numpy as jnp
4
+ import jax.tree as jt
5
+ from jaxtyping import Scalar, ScalarLike
6
+
7
+ from bayinx.core.constraint import Constraint
8
+ from bayinx.core.types import T
9
+
10
+
11
class Interval(Constraint):
    """
    Enforces that the parameter lies in the (lb, ub) interval using a scaled
    and shifted sigmoid transformation `u -> lb + (ub - lb) * sigmoid(u)`.

    # Attributes
    - `lb`: The lower bound.
    - `ub`: The upper bound.
    """

    lb: Scalar
    ub: Scalar

    def __init__(self, lb: ScalarLike, ub: ScalarLike):
        self.lb = jnp.asarray(lb)
        self.ub = jnp.asarray(ub)

    def constrain(self, obj: T, filter_spec: T) -> Tuple[T, Scalar]:
        """
        Applies the scaled sigmoid transformation to the selected leaves of a
        `PyTree` and computes the log-Jacobian adjustment.

        # Parameters
        - `obj`: The unconstrained `PyTree` (values are in R).
        - `filter_spec`: A `PyTree` matching `obj` whose leaves are used as
          truth values; only truthy leaves are transformed.

        # Returns
        A tuple containing:
        - A `PyTree` whose selected values `y` satisfy lb < y < ub in exact
          arithmetic (in floating point they saturate to the bounds for very
          large |u|).
        - A scalar `Array` containing the log-absolute-Jacobian of the
          transformation.
        """
        log_jac: Scalar = jnp.array(0.0)
        width = self.ub - self.lb
        log_width = jnp.log(width)

        def constrain_leaf(leaf: Any, include: bool):
            nonlocal log_jac  # Reference outer variable

            if not include:
                return leaf

            # log(sigmoid(u)) and log(sigmoid(-u)) via logaddexp. These stay
            # finite for large |u|, whereas exp(u) / (1 + exp(u)) becomes
            # inf / inf = nan once exp(u) overflows.
            log_sig_pos = -jnp.logaddexp(0.0, -leaf)
            log_sig_neg = -jnp.logaddexp(0.0, leaf)

            # Apply transformation
            constrained = self.lb + width * jnp.exp(log_sig_pos)

            # d/du [lb + (ub - lb) * sigmoid(u)] = (ub - lb) * sigmoid(u) * sigmoid(-u).
            # Computing it directly in log space avoids log(0) = -inf when
            # `constrained` rounds to exactly lb or ub.
            log_jac = log_jac + jnp.sum(log_width + log_sig_pos + log_sig_neg)

            return constrained

        # Constrain leaves
        obj = jt.map(constrain_leaf, obj, filter_spec)

        return obj, log_jac
bayinx/constraints/lower.py CHANGED
@@ -1,50 +1,57 @@
1
- from typing import Tuple
1
+ from typing import Any, Tuple
2
2
 
3
- import equinox as eqx
4
3
  import jax.numpy as jnp
5
4
  import jax.tree as jt
6
- from jaxtyping import PyTree, Scalar, ScalarLike
5
+ from jaxtyping import Scalar, ScalarLike
7
6
 
8
- from bayinx.core import Constraint, Parameter
7
+ from bayinx.core.constraint import Constraint
8
+ from bayinx.core.types import T
9
9
 
10
10
 
11
class Lower(Constraint):
    """
    Lower-bound constraint built on the shifted exponential map `u -> exp(u) + lb`.

    # Attributes
    - `lb`: The lower bound.
    """

    lb: Scalar

    def __init__(self, lb: ScalarLike):
        self.lb = jnp.asarray(lb)

    def constrain(self, obj: T, filter_spec: T) -> Tuple[T, Scalar]:
        """
        Map the selected leaves of `obj` above `lb` and accumulate the
        log-absolute-Jacobian of that map.

        # Parameters
        - `obj`: The unconstrained `PyTree`.
        - `filter_spec`: A `PyTree` matching `obj` whose leaves are used as
          truth values; only truthy leaves are transformed.

        # Returns
        A tuple containing:
        - `obj` with each selected leaf `u` replaced by `exp(u) + lb`
          (so `lb <= x`; strictly `lb < x` in exact arithmetic).
        - A scalar `Array` with the summed log-Jacobian, which for this map is
          simply the sum of the selected unconstrained values.
        """
        total_log_jac: Scalar = jnp.array(0.0)

        def transform(leaf: Any, selected: bool):
            nonlocal total_log_jac

            # Leaves outside the filter pass through untouched.
            if not selected:
                return leaf

            mapped = jnp.exp(leaf) + self.lb
            # d/du [exp(u) + lb] = exp(u)  =>  log|J| = u
            total_log_jac = total_log_jac + jnp.sum(leaf)
            return mapped

        constrained = jt.map(transform, obj, filter_spec)
        return constrained, total_log_jac
bayinx/constraints/upper.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Any, Tuple
2
+
3
+ import jax.numpy as jnp
4
+ import jax.tree as jt
5
+ from jaxtyping import Scalar, ScalarLike
6
+
7
+ from bayinx.core.constraint import Constraint
8
+ from bayinx.core.types import T
9
+
10
+
11
class Upper(Constraint):
    """
    Upper-bound constraint built on the reflected exponential map `u -> ub - exp(u)`.

    # Attributes
    - `ub`: The upper bound.
    """

    ub: Scalar

    def __init__(self, ub: ScalarLike):
        self.ub = jnp.asarray(ub)

    def constrain(self, obj: T, filter_spec: T) -> Tuple[T, Scalar]:
        """
        Map the selected leaves of `obj` below `ub` and accumulate the
        log-absolute-Jacobian of that map.

        # Parameters
        - `obj`: The unconstrained `PyTree`.
        - `filter_spec`: A `PyTree` matching `obj` whose leaves are used as
          truth values; only truthy leaves are transformed.

        # Returns
        A tuple containing:
        - `obj` with each selected leaf `u` replaced by `ub - exp(u)`
          (so `x <= ub`; strictly `x < ub` in exact arithmetic).
        - A scalar `Array` with the summed log-absolute-Jacobian, which for
          this map is the sum of the selected unconstrained values.
        """
        total_log_jac: Scalar = jnp.array(0.0)

        def transform(leaf: Any, selected: bool):
            nonlocal total_log_jac

            # Leaves outside the filter pass through untouched.
            if not selected:
                return leaf

            mapped = -jnp.exp(leaf) + self.ub
            # |d/du [ub - exp(u)]| = exp(u)  =>  log|J| = u
            total_log_jac = total_log_jac + jnp.sum(leaf)
            return mapped

        constrained = jt.map(transform, obj, filter_spec)
        return constrained, total_log_jac
bayinx/core/__init__.py CHANGED
@@ -1,7 +0,0 @@
1
- from ._constraint import Constraint
2
- from ._flow import Flow
3
- from ._model import Model, constrain
4
- from ._parameter import Parameter
5
- from ._variational import Variational
6
-
7
- __all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
bayinx/core/constraint.py ADDED
@@ -0,0 +1,32 @@
1
+ from abc import abstractmethod
2
+ from typing import TYPE_CHECKING, Tuple
3
+
4
+ import equinox as eqx
5
+ from jaxtyping import PyTree, Scalar
6
+
7
+ if TYPE_CHECKING:
8
+ from bayinx.core.types import T
9
+
10
+
11
class Constraint(eqx.Module):
    """
    Abstract interface for the constraints attached to stochastic nodes.

    Implementations map unconstrained values into the constrained domain and
    report the log-Jacobian adjustment of that map.
    """

    @abstractmethod
    def constrain(self, obj: "T", filter_spec: PyTree) -> Tuple["T", Scalar]:
        """
        Transform the selected leaves of `obj` and compute the log-Jacobian
        adjustment of the transformation.

        # Parameters
        - `obj`: The unconstrained values (a `PyTree`).
        - `filter_spec`: The filter specification for `obj`, selecting which
          leaves are transformed.

        # Returns
        A tuple containing:
        - A `PyTree` with its selected leaves now constrained.
        - A scalar `Array` containing the log-Jacobian adjustment of the
          transformation.
        """
        ...
bayinx/core/context.py ADDED
@@ -0,0 +1,42 @@
1
+ import threading
2
+ from contextlib import contextmanager
3
+ from dataclasses import dataclass
4
+
5
+ import jax.numpy as jnp
6
+ from jaxtyping import Scalar
7
+
8
# Thread-local holder for the active model context; `model_context()` attaches
# and detaches its `target` attribute.
_model_context = threading.local()


@dataclass
class Target:
    """
    Log-density accumulator for the implicit model context.

    JAX arrays are immutable, so `+=` on a `Target` rebinds its `value`
    attribute and returns the same `Target`. Plain `+` (from either side)
    returns the bare sum rather than a `Target`.
    """

    value: Scalar

    def __iadd__(self, other):
        self.value = self.value + other
        return self

    def __add__(self, other):
        return self.value + other

    # Reflected addition computes exactly the same expression.
    __radd__ = __add__
26
+
27
+
28
@contextmanager
def model_context():
    """
    Context manager that sets up a fresh implicit `target` accumulator.

    Inside the context, the `<<` operator for Node classes automatically
    accumulates log probabilities without explicit handling of target.

    On exit, any accumulator that was active before entering is restored. A
    nested context therefore does not destroy the enclosing one. If there was
    none, the attribute is removed.

    # Yields
    The `Target` accumulator for this context.
    """
    missing = object()
    previous = getattr(_model_context, "target", missing)

    _model_context.target = Target(jnp.array(0.0))

    try:
        yield _model_context.target
    finally:
        if previous is not missing:
            # Restore the enclosing context's accumulator
            _model_context.target = previous
        elif hasattr(_model_context, "target"):
            delattr(_model_context, "target")
bayinx/core/distribution.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import Protocol, Tuple
2
+
3
+ import jax.random as jr
4
+ from jaxtyping import PRNGKeyArray, Scalar
5
+
6
+ from bayinx.core.node import Node
7
+
8
+
9
class Distribution(Protocol):
    """
    Structural interface that distributions are expected to satisfy.
    """

    def logprob(self, node: Node) -> Scalar: ...

    # NOTE: the default `key` is created once, when this class is defined.
    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)): ...

    def __rlshift__(self, node: Node):
        """
        Handle `node << distribution` by adding `self.logprob(node)` to the
        active model context's `target`.

        # Raises
        - `RuntimeError`: if no model context is active.
        """
        from bayinx.core.context import _model_context

        # The log probability is evaluated before the context check.
        contribution = self.logprob(node)

        if not hasattr(_model_context, "target"):
            raise RuntimeError(
                "Model context doesn't exist. Make sure you're calling "
                + "this within the 'model' method."
            )

        _model_context.target += contribution
bayinx/core/flow.py ADDED
@@ -0,0 +1,99 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Callable, Dict, Self, Tuple
3
+
4
+ import equinox as eqx
5
+ import jax.tree_util as jtu
6
+ from jaxtyping import Array, PyTree
7
+
8
+
9
class FlowLayer(eqx.Module):
    """
    An abstract base class for a flow layer.

    # Attributes
    - `params`: The parameters of the diffeomorphism.
    - `constraints`: Maps a key of `params` to a function constraining that
      parameter to its domain.
    - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
    """

    # TODO: document `params` for all concrete flows.
    params: Dict[str, PyTree]
    constraints: Dict[str, Callable[[PyTree], Array]]
    static: bool

    @abstractmethod
    def forward(self, draws: Array) -> Array:
        """
        Computes the forward transformation of `draws`.
        """
        pass

    @abstractmethod
    def adjust(self, draws: Array) -> Array:
        """
        Computes the log-Jacobian adjustment for each draw in `draws`.

        # Returns
        An array of the log-Jacobian adjustments.
        """
        pass

    @abstractmethod
    def forward_and_adjust(self, draws: Array) -> Tuple[Array, Array]:
        """
        Computes the log-Jacobian adjustment at `draws` and applies the forward transformation.

        # Returns
        A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
        """
        pass

    # Default filter specification
    @property
    def filter_spec(self) -> "FlowLayer":
        """
        Generates a filter specification to subset relevant parameters for the flow.

        All leaves are `False` except, when the layer is not static, those
        under `params`.
        """
        # Generate empty specification
        filter_spec = jtu.tree_map(lambda _: False, self)

        if self.static is False:
            # Specify parameters
            filter_spec = eqx.tree_at(
                lambda flow: flow.params,
                filter_spec,
                replace=True,
            )

        return filter_spec

    def constrain_params(self: Self) -> Dict[str, PyTree]:
        """
        Constrain flow parameters to the appropriate domain.

        Returns a new dictionary; `self.params` is not modified.

        # Returns
        The constrained parameters of the diffeomorphism.
        """
        # Copy first. Assigning into `self.params` directly would mutate the
        # module's own dict, so each later call would constrain the already
        # constrained values again.
        params = dict(self.params)

        for name, constrain in self.constraints.items():
            params[name] = constrain(params[name])

        return params

    def transform_params(self: Self) -> Dict[str, PyTree]:
        """
        Apply a custom transformation to `params` if needed. Defaults to `constrain_params()`.

        # Returns
        The transformed parameters of the diffeomorphism.
        """
        return self.constrain_params()
91
+
92
+
93
class FlowSpec(ABC):
    """
    Specification from which a concrete flow layer is built.
    """

    @abstractmethod
    def construct(self, dim: int) -> FlowLayer:
        """
        Build the `FlowLayer` described by this specification for dimension `dim`.
        """
        ...
bayinx/core/model.py ADDED
@@ -0,0 +1,228 @@
1
+ from abc import abstractmethod
2
+ from dataclasses import field, fields
3
+ from typing import (
4
+ Dict,
5
+ Optional,
6
+ Self,
7
+ Tuple,
8
+ Type,
9
+ get_origin,
10
+ get_type_hints,
11
+ )
12
+
13
+ import equinox as eqx
14
+ import jax.numpy as jnp
15
+ import jax.tree as jt
16
+ from jaxtyping import PyTree, Scalar
17
+
18
+ from bayinx.constraints import Identity, Interval, Lower, Upper
19
+ from bayinx.core.context import _model_context, model_context
20
+ from bayinx.core.node import Node
21
+ from bayinx.core.types import HasConstraint
22
+ from bayinx.core.utils import _extract_shape_params, _resolve_shape_spec
23
+ from bayinx.nodes import Continuous, Observed, Stochastic
24
+
25
+
26
+ def define(
27
+ shape: Optional[int | str | Tuple[int | str, ...]] = None,
28
+ init: Optional[PyTree] = None,
29
+ lower: Optional[float] = None,
30
+ upper: Optional[float] = None
31
+ ):
32
+ """
33
+ Define a stochastic node.
34
+ """
35
+ metadata: Dict = {}
36
+
37
+ if shape is not None:
38
+ metadata["shape"] = shape
39
+ if init is not None:
40
+ metadata["init"] = init
41
+
42
+ match (lower, upper):
43
+ case (float() | int(), None):
44
+ metadata["constraint"] = Lower(lower) # type: ignore
45
+ case (None, float() | int()):
46
+ metadata["constraint"] = Upper(upper) # type: ignore
47
+ case (float() | int(), float() | int()):
48
+ metadata["constraint"] = Interval(lower, upper) # type: ignore
49
+ case (None, None):
50
+ metadata["constraint"] = Identity()
51
+ case (_):
52
+ raise TypeError("TODO.")
53
+
54
+ return field(metadata=metadata)
55
+
56
+
57
class Model(eqx.Module):
    """
    A base class used to define probabilistic models.

    Subclasses declare nodes as annotated fields (optionally created with
    `define(...)`) and implement `model`. Calling an instance constrains its
    nodes, evaluates `model`, and returns the accumulated log density.
    """

    def __init_subclass__(cls, **kwargs):
        # Consume 'init' argument before passing it up to Equinox
        kwargs.pop("init", None)
        super().__init_subclass__(**kwargs)

    def __init__(self, **kwargs):
        """
        Construct every node from keyword arguments and field definitions.

        A node's value is taken from, in order of precedence:
        1. a keyword argument with the field's name;
        2. the field's `init` metadata;
        3. zeros of the field's resolved `shape`.
        Named dimensions used in `shape` specifications must be supplied as
        keyword arguments.

        # Raises
        - `TypeError`: if shape parameters are missing, a node type is not
          implemented, or a field is neither `Stochastic` nor `Observed`.
        - `ValueError`: if a node cannot be initialized or has the wrong shape.
        """
        cls = self.__class__
        node_defns = fields(cls)

        # Nodes whose values were passed explicitly
        init_params: set[str] = {f.name for f in node_defns if f.name in kwargs}

        # Named dimensions referenced by any node's shape specification
        shape_params: set[str] = set()
        for node_defn in node_defns:
            if (shape_spec := node_defn.metadata.get("shape")) is not None:
                shape_params = shape_params | _extract_shape_params(shape_spec)

        # Check all shape parameters are passed as arguments. Single quotes
        # inside the f-string keep this valid on Python < 3.12.
        missing_params = shape_params - kwargs.keys()
        if missing_params:
            raise TypeError(
                "Following shape parameters were not specified during model "
                f"initialization: {', '.join(sorted(missing_params))}"
            )

        # Concrete values for each named dimension
        shape_values: dict = {name: kwargs[name] for name in shape_params}

        # Resolve node classes from annotations. `get_origin` reduces a
        # subscripted (generic) annotation to its class but returns None for a
        # bare class, so fall back to the annotation itself.
        node_types: dict[str, Type[Node]] = {
            name: get_origin(hint) or hint
            for name, hint in get_type_hints(cls).items()
        }

        # Auto-initialize nodes based on field metadata and type annotations
        for node_defn in node_defns:
            name = node_defn.name
            node_type = node_types[name]

            # Grab shape information if available
            shape_spec: str | None = node_defn.metadata.get("shape")
            shape = _resolve_shape_spec(shape_spec, shape_values)

            # Construct object
            if name in init_params:  # Initialized in model construction
                obj = kwargs[name]
            elif "init" in node_defn.metadata:  # Initialized in model definition
                obj = node_defn.metadata["init"]
            elif shape is not None:  # Shape defined in model definition
                obj = jnp.zeros(shape)  # TODO: will change later for discrete objects
            else:
                raise ValueError(f"Node '{name}' not initialized or defined.")

            # Check shape
            if shape is not None and jnp.shape(obj) != shape:
                raise ValueError(f"Expected shape {shape} for {name} but got {jnp.shape(obj)}.")

            # `issubclass` requires a class; non-class annotations fall through
            # to the final TypeError below.
            is_class = isinstance(node_type, type)

            if is_class and issubclass(node_type, Stochastic):
                if node_type != Continuous:
                    raise TypeError(f"{node_type.__name__} is not implemented yet")

                # Fields not created via `define(...)` carry no constraint
                # metadata; treat them as unconstrained, like `define()` does.
                constraint = node_defn.metadata.get("constraint", Identity())
                setattr(self, name, Continuous(obj, constraint))
            elif is_class and issubclass(node_type, Observed):
                setattr(self, name, Observed(obj))
            else:
                type_name = getattr(node_type, "__name__", repr(node_type))
                raise TypeError(f"{name} node is neither Stochastic nor Observed but {type_name}.")

    @property
    def filter_spec(self) -> Self:
        """
        Generates a filter specification to subset stochastic elements of the model.
        """
        # Equivalent to filtering for the `Stochastic` base class.
        return self.filter_for(Stochastic)

    def filter_for(self, node_type: Type[Stochastic]) -> Self:
        """
        Generates a filter specification to subset stochastic elements of a certain type of the model.

        Every leaf is `False` except those of nodes that are instances of
        `node_type`, which are replaced by that node's own `filter_spec`.
        """
        # Generate empty specification
        filter_spec: Self = jt.map(lambda _: False, self)

        for f in fields(self):  # type: ignore
            # Extract node
            node: Node = getattr(self, f.name)

            if isinstance(node, node_type):
                # Update model's filter specification for node
                filter_spec: Self = eqx.tree_at(
                    lambda model: getattr(model, f.name),
                    filter_spec,
                    replace=node.filter_spec,
                )

        return filter_spec

    def constrain(self, jacobian: bool = True) -> Tuple[Self, Scalar]:
        """
        Constrain nodes to the appropriate domain.

        # Parameters
        - `jacobian`: Whether to include the log-Jacobian adjustments in the
          returned total (otherwise it stays 0).

        # Returns
        A tuple containing the constrained `Model` object and the log-Jacobian adjustment.
        """
        model: Self = self
        target: Scalar = jnp.array(0.0)

        for f in fields(self):  # type: ignore
            # Extract attribute
            node = getattr(self, f.name)

            # Check if node has a constraint
            if isinstance(node, HasConstraint):
                # Apply constraint
                obj, laj = node._constraint.constrain(node.obj, node._filter_spec)

                # Update values with constrained counterpart
                model = eqx.tree_at(
                    where=lambda model: getattr(model, f.name).obj,
                    pytree=model,
                    replace=obj,
                )

                # Adjust posterior density
                if jacobian:
                    target += laj

        return model, target

    @abstractmethod
    def model(self, target: Scalar) -> Scalar:
        """
        Define the model's log density.

        `__call__` invokes this on the constrained model with `target` set to a
        zero scalar. The returned value is added to the total, along with any
        increments accumulated implicitly in the model context.
        """
        pass

    @eqx.filter_jit
    def __call__(self) -> Scalar:
        """
        Evaluate the model's total log density.

        The total is the sum of the Jacobian adjustments, the value returned by
        `model`, and the implicit increments recorded in the model context.
        """
        with model_context():
            # Constrain the model and accumulate Jacobian adjustments
            constrained, target = self.constrain()

            # Accumulate manual increments
            target += constrained.model(jnp.array(0.0))

            # Accumulate implicit increments
            target += _model_context.target.value

        # Return the accumulated target
        return target