bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/core/node.py ADDED
@@ -0,0 +1,201 @@
1
+ from typing import Any, Generic, Iterator, Self
2
+
3
+ import equinox as eqx
4
+ import jax.tree as jt
5
+ from jaxtyping import PyTree
6
+
7
+ from bayinx.core.types import T
8
+ from bayinx.core.utils import _extract_obj, _merge_filter_specs
9
+
10
+
11
class Node(eqx.Module, Generic[T]):
    """
    A thin wrapper for nodes of a probabilistic model.

    # Attributes
    - `obj`: An internal realization of the node.
    - `_filter_spec`: An internal filter specification for `obj`.
    """

    obj: T
    _filter_spec: PyTree

    def __init__(self, obj, filter_spec):
        self.obj = obj
        self._filter_spec = filter_spec

    @property
    def filter_spec(self) -> Self:
        """
        An outer filter specification for the full node.

        Every leaf of the node is first marked `False`; the `obj` field is then
        replaced by the inner specification `_filter_spec`.
        """
        # Generate empty specification
        node_filter_spec: Self = jt.map(lambda _: False, self)

        # Filter based on inner filter specification for 'obj'
        node_filter_spec = eqx.tree_at(
            lambda node: node.obj,
            node_filter_spec,
            replace=self._filter_spec,
        )

        return node_filter_spec

    def _has_elementwise_spec(self) -> bool:
        """
        Whether `_filter_spec` has the same type as `obj` and should therefore be
        subset/iterated alongside it (otherwise it applies to every element as is).
        """
        return type(self.obj) is type(self._filter_spec)

    # Wrappers around internal dunder methods ----
    def __getitem__(self, key: Any) -> "Node":
        """
        Subset the node's object, and its filter specification when it mirrors `obj`.

        # Raises
        - `TypeError`: if `key` is itself a `Node`.
        """
        if isinstance(key, Node):
            raise TypeError("Subsetting nodes with nodes is not yet supported.")

        # Subset internally
        new_obj = self.obj[key]
        if self._has_elementwise_spec():
            new_filter_spec = self._filter_spec[key]
        else:
            # A non-mirroring specification (e.g. an implicit `True`) covers every element
            new_filter_spec = self._filter_spec

        # Create new subsetted node
        return type(self)(new_obj, new_filter_spec)

    def __iter__(self) -> Iterator["Node"]:
        """
        Iterate over the elements of `obj`, yielding each one wrapped in a `Node`.

        The filter specification is iterated in lockstep when it mirrors `obj`
        (same rule as `__getitem__`); otherwise every element reuses it unchanged.
        Previously a non-iterable specification such as `True` made this crash.
        """
        if self._has_elementwise_spec():
            pairs = zip(self.obj, self._filter_spec)
        else:
            pairs = ((obj_i, self._filter_spec) for obj_i in self.obj)

        for obj_i, spec_i in pairs:
            # Create a new Node for the current element
            yield Node(obj_i, spec_i)

    def _binary_op(self, other: Any, op, reflected: bool = False) -> "Node":
        """
        Apply a binary operator to the wrapped objects and merge their filter specifications.

        # Parameters
        - `other`: The other operand; either a `Node` or a plain object supporting `op`.
        - `op`: A two-argument callable implementing the operator, called as `op(lhs, rhs)`.
        - `reflected`: If `True`, `other` is the left-hand operand (used by the `__r*__` methods).

        # Returns
        A new `Node` wrapping the result. Its specification merges only the operand
        specifications whose objects share the result's type.
        """
        # Extract internal objects and their (possibly implicit) filter specifications
        self_obj, self_spec = _extract_obj(self)
        other_obj, other_spec = _extract_obj(other)

        if reflected:
            lhs_obj, lhs_spec, rhs_obj, rhs_spec = other_obj, other_spec, self_obj, self_spec
        else:
            lhs_obj, lhs_spec, rhs_obj, rhs_spec = self_obj, self_spec, other_obj, other_spec

        # Perform the operation
        new_obj = op(lhs_obj, rhs_obj)

        # Merge filter specifications
        new_filter_spec = _merge_filter_specs(
            [lhs_spec, rhs_spec],
            [lhs_obj, rhs_obj],
            new_obj,
        )

        return Node(new_obj, new_filter_spec)

    ## Arithmetic ----
    def __add__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a + b)

    def __radd__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a + b, reflected=True)

    def __sub__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a - b)

    def __rsub__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a - b, reflected=True)

    def __mul__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a * b)

    def __rmul__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a * b, reflected=True)

    def __matmul__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a @ b)

    def __rmatmul__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a @ b, reflected=True)

    def __truediv__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a / b)

    def __rtruediv__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a / b, reflected=True)

    def __floordiv__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a // b)

    def __rfloordiv__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a // b, reflected=True)

    def __pow__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a ** b)

    def __rpow__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a ** b, reflected=True)

    def __mod__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a % b)

    def __rmod__(self, other: Any) -> "Node":
        return self._binary_op(other, lambda a, b: a % b, reflected=True)
bayinx/core/types.py ADDED
@@ -0,0 +1,17 @@
1
+ from typing import Generic, Protocol, TypeVar, runtime_checkable
2
+
3
+ from jaxtyping import PyTree
4
+
5
+ from bayinx.core.constraint import Constraint
6
+
7
# Type variable for the realization wrapped by a node (bounded by `PyTree`).
T = TypeVar("T", bound=PyTree)

@runtime_checkable
class HasConstraint(Protocol, Generic[T]):
    """
    Structural type for probabilistic nodes that carry a constraint.

    Any object exposing the three attributes below satisfies this protocol; the
    `runtime_checkable` decorator allows it to be used with `isinstance`.
    """

    # Internal realization of the node.
    obj: T
    # Filter specification for `obj`.
    _filter_spec: PyTree
    # Constraint attached to the node.
    _constraint: Constraint
bayinx/core/utils.py ADDED
@@ -0,0 +1,109 @@
1
+
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import jax.tree as jt
5
+ from jaxtyping import PyTree
6
+
7
+
8
+ def _extract_shape_params(shape_spec: int | str | Tuple[int | str, ...]) -> set[str]:
9
+ """
10
+ Extract parameter names from shape specification.
11
+ """
12
+ params = set()
13
+ if isinstance(shape_spec, str):
14
+ params.add(shape_spec)
15
+ elif isinstance(shape_spec, tuple):
16
+ for item in shape_spec:
17
+ if isinstance(item, str):
18
+ params.add(item)
19
+ #
20
+ return params
21
+
22
+ def _resolve_shape_spec(
23
+ shape_spec: None | int | str | Tuple[int | str, ...],
24
+ shape_values: Dict[str, int]
25
+ ) -> None | Tuple[int, ...]:
26
+ """
27
+ Replaces named dimensions in a shape specification with their integer or tuple values.
28
+
29
+ # Example
30
+ For `shape_values = {'k': 5, 's': (3,2,1)}`:
31
+ `('k', 4, 's')` --> `(5, 4, 3, 2, 1)`
32
+ """
33
+ if shape_spec is None:
34
+ return None
35
+
36
+ # Coerce to tuple for uniform processing
37
+ if isinstance(shape_spec, (int, str)):
38
+ shape_spec = (shape_spec,)
39
+
40
+ resolved_spec: List[int] = []
41
+ for dim in shape_spec:
42
+ if isinstance(dim, str):
43
+ if dim in shape_values:
44
+ resolved_value = shape_values[dim] # Grab initialized value
45
+
46
+ if isinstance(resolved_value, int):
47
+ # Scalar shape dimension (e.g., 'k' -> 3)
48
+ resolved_spec.append(resolved_value)
49
+ elif isinstance(resolved_value, tuple):
50
+ # Packed shape dimension (e.g., 'shape' -> (3, 2, 1))
51
+ resolved_spec.extend(resolved_value)
52
+ else:
53
+ raise TypeError(f"Shape parameter '{dim}' resolved to an unsupported type: {type(resolved_value).__name__}")
54
+ else: # dim not in shape_values
55
+ raise TypeError(f"Shape parameter '{dim}' was not initialized with a value.")
56
+ elif isinstance(dim, int):
57
+ # Literal integer (e.g., 3 -> 3)
58
+ resolved_spec.append(dim)
59
+ else:
60
+ raise TypeError(f"Shape parameter {dim} was incorrectly specified (must be 'int' or 'str', got '{type(dim).__name__}').")
61
+
62
+ return tuple(resolved_spec)
63
+
64
def _extract_obj(x: Any) -> Tuple[Any, Any]:
    """
    Split `x` into an underlying object and that object's filter specification.

    A `Node` contributes its stored `obj` and `_filter_spec`. Any other value is
    returned unchanged, paired with the implicit specification `True`.
    """
    # Deferred import: bayinx.core.node imports this module at load time
    from bayinx.core.node import Node

    if not isinstance(x, Node):
        # Plain values carry an implicit, all-inclusive filter specification
        return (x, True)

    return (x.obj, x._filter_spec)
78
+
79
+
80
+ def _merge_filter_specs(
81
+ filter_specs: List[PyTree],
82
+ objs: Optional[List[PyTree]] = None,
83
+ obj: Optional[PyTree] = None
84
+ ) -> PyTree:
85
+ """
86
+ Merge filter specifications that share the type of `obj`.
87
+
88
+ If `obj` and `objs` are not provided then `filter_specs` will be merged as is.
89
+ """
90
+
91
+ def _merge(*args):
92
+ return all(args)
93
+
94
+ if objs is None and obj is None:
95
+ filter_spec: Any = jt.map(_merge, *filter_specs) # type: ignore
96
+ else:
97
+ selected_specs: List[PyTree] = []
98
+
99
+ # Include filter specs whose objects share the correct type
100
+ for cur_obj, cur_spec in zip(objs, filter_specs): # type: ignore
101
+ if type(obj) is type(cur_obj):
102
+ selected_specs.append(cur_spec)
103
+
104
+ if len(selected_specs) != 0:
105
+ filter_spec = jt.map(_merge, *selected_specs)
106
+ else:
107
+ filter_spec = True
108
+
109
+ return filter_spec
bayinx/core/variational.py ADDED
@@ -0,0 +1,170 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ from typing import Callable, Generic, Self, Tuple, TypeVar
4
+
5
+ import equinox as eqx
6
+ import jax
7
+ import jax.lax as lax
8
+ import jax.numpy as jnp
9
+ import jax.random as jr
10
+ import optax as opx
11
+ from jaxtyping import Array, PRNGKeyArray, PyTree, Scalar
12
+ from optax import GradientTransformation, OptState
13
+
14
+ from bayinx.core.model import Model
15
+
16
# Type of the `Model` a `Variational` approximates.
M = TypeVar("M", bound=Model)

class Variational(eqx.Module, Generic[M]):
    """
    An abstract base class used to define variational methods.

    # Attributes
    - `dim`: The dimension of the support.
    - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
    - `_static`: The static component of a partitioned `Model` used to initialize the `Variational` object.
    """
    dim: int
    _unflatten: Callable[[Array], M]
    _static: M

    @property
    @abstractmethod
    def filter_spec(self) -> Self:
        """
        Filter specification for dynamic and static components of the
        `Variational` object.
        """
        pass

    @abstractmethod
    def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
        """
        Sample from the variational distribution.
        """
        pass

    @abstractmethod
    def eval(self, draws: Array) -> Array:
        """
        Evaluate the variational distribution at `draws`.
        """
        pass

    @abstractmethod
    def elbo(self, n: int, batch_size: int, key: PRNGKeyArray) -> Array:
        """
        Evaluate the ELBO.
        """
        pass

    @abstractmethod
    def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> M:
        """
        Evaluate the gradient of the ELBO.
        """
        pass

    @abstractmethod
    def elbo_and_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Tuple[Scalar, PyTree]:
        """
        Evaluate the ELBO and its gradient.
        """
        pass

    @eqx.filter_jit
    def reconstruct_model(self, draw: Array) -> M:
        """
        Rebuild a full `Model` from a single variational draw.

        # Parameters
        - `draw`: One draw from the variational distribution.

        # Returns
        The result of `_unflatten(draw)` recombined with the static components `_static`.
        """
        # Unflatten variational draw
        model: M = self._unflatten(draw)

        # Combine with static components
        model: M = eqx.combine(model, self._static)

        return model

    @eqx.filter_jit
    @partial(jax.vmap, in_axes=(None, 0))
    def eval_model(self, draws: Array) -> Array:
        """
        Reconstruct models from variational draws and evaluate their posterior.

        Vectorized with `jax.vmap` over the leading axis of `draws` (`self` is not mapped).

        # Parameters
        - `draws`: A set of variational draws.
        """
        # Unflatten variational draw
        model: M = self.reconstruct_model(draws)

        # Evaluate posterior
        return model()

    @eqx.filter_jit
    def fit(
        self,
        max_iters: int,
        learning_rate: float,
        tolerance: float,
        grad_draws: int,
        batch_size: int,
        key: PRNGKeyArray = jr.key(0),
    ) -> Self:
        """
        Optimize the variational distribution by stochastic gradient ascent on the ELBO.

        # Parameters
        - `max_iters`: Number of optimization steps; the loop always runs exactly this many.
        - `learning_rate`: Peak learning rate of the warmup-cosine schedule.
        - `tolerance`: Currently unused (no convergence check is performed).
        - `grad_draws`: Passed to `elbo_grad` as `n` on every step.
        - `batch_size`: Passed to `elbo_grad` as `batch_size` on every step.
        - `key`: PRNG key driving the stochastic gradient estimates.

        # Returns
        The optimized `Variational` object.
        """
        # Partition variational
        dyn, static = eqx.partition(self, self.filter_spec)

        # Construct scheduler: linear warmup over the first 10% of iterations, then
        # cosine decay. optax's `decay_steps` is the TOTAL schedule length (warmup
        # included), so it must be `max_iters`; passing `max_iters - warmup` made the
        # schedule reach 0 early and left the final iterations with no update.
        warmup_steps = int(max_iters / 10)
        schedule = opx.warmup_cosine_decay_schedule(
            1e-8,
            learning_rate,
            warmup_steps,
            max_iters,
        )

        # Initialize optimizer; gradients are negated first so the descent-style
        # update from `adam` moves in the direction that increases the ELBO
        optim: GradientTransformation = opx.chain(
            opx.scale(-1.0), opx.adam(schedule)
        )
        opt_state: OptState = optim.init(dyn)

        def condition(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]):
            # Unpack iteration state
            dyn, opt_state, i, key = state

            return i < max_iters

        def body(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]):
            # Unpack iteration state
            dyn, opt_state, i, key = state

            # Update iteration
            i = i + 1

            # Split off a fresh subkey for this step; the carried key is never
            # consumed directly, so no PRNG key is reused across iterations
            key, subkey = jr.split(key)

            # Reconstruct variational
            vari: Self = eqx.combine(dyn, static)

            # Compute gradient of the ELBO for update
            update: M = vari.elbo_grad(grad_draws, batch_size, subkey)

            # Transform update through optimizer
            update, opt_state = optim.update(  # type: ignore
                update, opt_state, eqx.filter(dyn, dyn.filter_spec)  # type: ignore
            )

            # Update variational distribution
            dyn: Self = eqx.apply_updates(dyn, update)

            return dyn, opt_state, i, key

        # Run optimization loop
        dyn = lax.while_loop(
            cond_fun=condition,
            body_fun=body,
            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
        )[0]

        # Return optimized variational
        return eqx.combine(dyn, static)
bayinx/dists/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
- from bayinx.dists import censored, gamma2, normal, posnormal
2
-
3
- __all__ = ['censored', "gamma2", "normal", "posnormal"]
1
+ from .bernoulli import Bernoulli as Bernoulli
2
+ from .binomial import Binomial as Binomial
3
+ from .exponential import Exponential as Exponential
4
+ from .normal import Normal as Normal
5
+ from .poisson import Poisson as Poisson