bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,64 @@
1
+ from typing import Any, Optional
2
+
3
+ import equinox as eqx
4
+ import jax.numpy as jnp
5
+ import jax.tree as jt
6
+ import numpy as np
7
+ from jaxtyping import Array, PyTree
8
+
9
+ from bayinx.constraints import Identity
10
+ from bayinx.core.constraint import Constraint
11
+ from bayinx.core.types import T
12
+ from bayinx.nodes.stochastic import Stochastic
13
+
14
+
15
def is_float_like(element: Any) -> bool:
    """
    Report whether `element` holds floating-point data.

    Objects exposing `__jax_array__` are first converted through that hook.
    NumPy arrays/scalars and JAX arrays are then judged by their dtype; any
    other object counts as float-like only if it is a Python `float`.

    The structure of this helper is adapted from the `Equinox` library.
    """
    value = element.__jax_array__() if hasattr(element, "__jax_array__") else element

    if isinstance(value, (np.ndarray, np.generic)):
        return bool(np.issubdtype(value.dtype, np.floating))
    if isinstance(value, Array):
        return jnp.issubdtype(value.dtype, jnp.floating)
    # Plain Python objects: only genuine floats qualify (bool/int do not).
    return isinstance(value, float)
29
+
30
+
31
class Continuous(Stochastic[T]):
    """
    A container for continuous stochastic nodes of a probabilistic model.

    # Attributes
    - `obj`: An internal realization of the node.
    - `_filter_spec`: An internal filter specification for `obj`.
    - `_constraint`: A constraining transformation.
    """

    _constraint: Constraint

    def __init__(
        self,
        obj: T,
        constraint: Constraint = Identity(),
        filter_spec: Optional[PyTree] = None
    ):
        """
        # Parameters
        - `obj`: The realization stored in the node.
        - `constraint`: The constraining transformation. The default
          `Identity()` instance is created once, at class definition, and
          shared by every node that does not pass its own.
        - `filter_spec`: Optional filter specification for `obj`. When omitted,
          each leaf of `obj` is marked with `is_float_like(leaf)`.
        """
        if filter_spec is None:
            # Per-leaf mask: True exactly where the leaf is float-like.
            float_leaves = jt.map(is_float_like, obj)
            # Graft that mask onto an all-False skeleton with the structure of `obj`.
            all_false = jt.map(lambda _: False, obj)
            filter_spec = eqx.tree_at(
                where=lambda tree: tree,
                pytree=all_false,
                replace=float_leaves,
            )

        self.obj = obj
        self._filter_spec = filter_spec
        self._constraint = constraint
@@ -0,0 +1,36 @@
1
+ from typing import Optional
2
+
3
+ import equinox as eqx
4
+ import jax.tree as jt
5
+ from jaxtyping import PyTree
6
+
7
+ from bayinx.core.node import Node
8
+ from bayinx.core.types import T
9
+
10
+
11
class Observed(Node[T]):
    """
    A container for observed nodes of a probabilistic model.

    # Attributes
    - `obj`: An internal realization of the node.
    - `_filter_spec`: An internal filter specification for `obj`.
    """

    def __init__(self, obj: T, filter_spec: Optional[PyTree] = None):
        """
        # Parameters
        - `obj`: The observed realization stored in the node.
        - `filter_spec`: Optional filter specification for `obj`. When omitted,
          each leaf of `obj` is marked with `eqx.is_array_like(leaf)`.
        """
        if filter_spec is None:
            # Per-leaf mask: True exactly where equinox deems the leaf array-like.
            array_leaves = jt.map(eqx.is_array_like, obj)
            # Graft that mask onto an all-False skeleton with the structure of `obj`.
            all_false = jt.map(lambda _: False, obj)
            filter_spec = eqx.tree_at(
                where=lambda tree: tree,
                pytree=all_false,
                replace=array_leaves,
            )

        self.obj = obj
        self._filter_spec = filter_spec
@@ -0,0 +1,25 @@
1
+ from abc import abstractmethod
2
+ from typing import Optional
3
+
4
+ from bayinx.core.node import Node
5
+ from bayinx.core.types import T
6
+
7
+
8
class Stochastic(Node[T]):
    """
    A container for stochastic (unobserved) nodes of a probabilistic model.

    This class cannot be instantiated directly: `__init__` is abstract, so
    concrete subclasses must provide their own constructor.
    NOTE(review): the original docstring also says subclasses implement a
    `filter_spec` property; that property is not visible here (it may live on
    `Node`), so confirm before relying on it.

    # Attributes
    - `obj`: An internal realization of the node.
    - `_filter_spec`: An internal filter specification for `obj`.
    """

    @abstractmethod
    def __init__(
        self,
        obj: T,
        filter_spec: Optional[T],
    ):
        """
        Abstract constructor; concrete subclasses override this.

        NOTE(review): `filter_spec` is annotated `Optional[T]` here, while the
        visible subclasses annotate it as `Optional[PyTree]`.
        """
bayinx/ops.py ADDED
@@ -0,0 +1,104 @@
1
+ import jax.numpy as jnp
2
+ import jax.tree as jt
3
+ from jaxtyping import Array, ArrayLike, Real
4
+
5
+ from bayinx.core.node import Node
6
+ from bayinx.core.utils import _extract_obj
7
+
8
+
9
def exp(node: Node) -> Node:
    """
    Apply the exponential transformation (`jnp.exp`) leaf-wise to a node.

    # Parameters
    - `node`: The node whose realization is transformed.

    # Returns
    A new node of the same type as `node`, wrapping the transformed
    realization and reusing the filter specification extracted from `node`.
    """
    obj, filter_spec = _extract_obj(node)

    # Pass `filter_spec` by keyword: `Continuous.__init__` takes `constraint`
    # as its second positional parameter, so a positional call would store the
    # filter specification as the constraint and silently regenerate the spec.
    return type(node)(jt.map(jnp.exp, obj), filter_spec=filter_spec)
23
+
24
+
25
def log(node: Node) -> Node:
    """
    Apply the natural logarithm transformation (`jnp.log`) leaf-wise to a node.

    Inputs are expected to be positive. No validation is performed: `jnp.log`
    yields `-inf` at zero and `nan` for negative values.

    # Parameters
    - `node`: The node whose realization is transformed.

    # Returns
    A new node of the same type as `node`, wrapping the transformed
    realization and reusing the filter specification extracted from `node`.
    """
    obj, filter_spec = _extract_obj(node)

    # Pass `filter_spec` by keyword: `Continuous.__init__` takes `constraint`
    # as its second positional parameter, so a positional call would store the
    # filter specification as the constraint and silently regenerate the spec.
    return type(node)(jt.map(jnp.log, obj), filter_spec=filter_spec)
40
+
41
+
42
def sin(node: Node) -> Node:
    """
    Apply the sine transformation (`jnp.sin`) leaf-wise to a node.

    # Parameters
    - `node`: The node whose realization is transformed.

    # Returns
    A new node of the same type as `node`, wrapping the transformed
    realization and reusing the filter specification extracted from `node`.
    """
    obj, filter_spec = _extract_obj(node)

    # Pass `filter_spec` by keyword: `Continuous.__init__` takes `constraint`
    # as its second positional parameter, so a positional call would store the
    # filter specification as the constraint and silently regenerate the spec.
    return type(node)(jt.map(jnp.sin, obj), filter_spec=filter_spec)
56
+
57
+
58
def cos(node: Node) -> Node:
    """
    Apply the cosine transformation (`jnp.cos`) leaf-wise to a node.

    # Parameters
    - `node`: The node whose realization is transformed.

    # Returns
    A new node of the same type as `node`, wrapping the transformed
    realization and reusing the filter specification extracted from `node`.
    """
    obj, filter_spec = _extract_obj(node)

    # Pass `filter_spec` by keyword: `Continuous.__init__` takes `constraint`
    # as its second positional parameter, so a positional call would store the
    # filter specification as the constraint and silently regenerate the spec.
    return type(node)(jt.map(jnp.cos, obj), filter_spec=filter_spec)
72
+
73
+
74
def tanh(node: Node) -> Node:
    """
    Apply the hyperbolic tangent transformation (`jnp.tanh`) leaf-wise to a node.

    # Parameters
    - `node`: The node whose realization is transformed.

    # Returns
    A new node of the same type as `node`, wrapping the transformed
    realization and reusing the filter specification extracted from `node`.
    """
    obj, filter_spec = _extract_obj(node)

    # Pass `filter_spec` by keyword: `Continuous.__init__` takes `constraint`
    # as its second positional parameter, so a positional call would store the
    # filter specification as the constraint and silently regenerate the spec.
    return type(node)(jt.map(jnp.tanh, obj), filter_spec=filter_spec)
88
+
89
+
90
def sigmoid(node: Node) -> Node:
    """
    Apply the sigmoid (logistic) transformation leaf-wise to a node.

    Sigmoid formula: 1 / (1 + exp(-x)), evaluated with `jax.nn.sigmoid`.

    # Parameters
    - `node`: The node whose realization is transformed.

    # Returns
    A new node of the same type as `node`, wrapping the transformed
    realization and reusing the filter specification extracted from `node`.
    """
    # `jax.nn.sigmoid` is the numerically stable logistic. The literal
    # 1 / (1 + exp(-x)) overflows `exp(-x)` for very negative x, which makes
    # the gradient inf/inf = NaN.
    from jax.nn import sigmoid as _logistic

    obj, filter_spec = _extract_obj(node)

    # Pass `filter_spec` by keyword: `Continuous.__init__` takes `constraint`
    # as its second positional parameter, so a positional call would store the
    # filter specification as the constraint and silently regenerate the spec.
    return type(node)(jt.map(_logistic, obj), filter_spec=filter_spec)
bayinx/posterior.py ADDED
@@ -0,0 +1,220 @@
1
+
2
+ from functools import partial
3
+ from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type
4
+
5
+ import equinox as eqx
6
+ import jax
7
+ import jax.random as jr
8
+ from jax.lax import scan
9
+ from jaxtyping import Array, PRNGKeyArray
10
+
11
+ from bayinx.core.flow import FlowSpec
12
+ from bayinx.core.node import Node
13
+ from bayinx.core.variational import M
14
+ from bayinx.vi.normalizing_flow import NormalizingFlow
15
+ from bayinx.vi.standard import Standard
16
+
17
+
18
class Posterior(Generic[M]):
    """
    The posterior distribution for a model.

    # Attributes
    - `vari`: The variational approximation of the posterior. It is a
      `Standard` base distribution until flows are added via `configure`,
      after which it is a `NormalizingFlow`.
    - `config`: The optimization configuration for the posterior
      (`learning_rate`, `tolerance`, `grad_draws`, `batch_size`).
    """
    vari: Standard | NormalizingFlow[M]
    config: Dict[str, Any]

    def __init__(self, model_def: Type[M], **kwargs: Any):
        """
        Construct the posterior for a model.

        # Parameters
        - `model_def`: The model class, instantiated as `model_def(**kwargs)`.
        - `**kwargs`: Keyword arguments forwarded to `model_def`.
        """
        # (hopefully) omit intermediate model construction through jit
        @eqx.filter_jit
        def construct_base(model_def):
            return Standard(model_def(**kwargs))

        # Construct standard normal base distribution
        self.vari = construct_base(model_def)

        # Include default attributes
        self.config = {
            "learning_rate": 0.1 / self.vari.dim**0.5,
            "tolerance": 1e-4,
            "grad_draws": 4,
            "batch_size": 1,
        }

    def _update_config(self, **settings: Optional[Any]) -> None:
        """Overwrite each `config` entry whose new value is not `None`."""
        for name, value in settings.items():
            if value is not None:
                self.config[name] = value

    def configure(
        self,
        flowspecs: Optional[List[FlowSpec]] = None,
        learning_rate: Optional[float] = None,
        tolerance: Optional[float] = None,
        grad_draws: Optional[int] = None,
        batch_size: Optional[int] = None
    ):
        """
        Configure the variational approximation.

        # Parameters
        - `flowspecs`: The specification for a sequence of flows. While the
          approximation is still the `Standard` base, a `NormalizingFlow` is
          built from these flows. If it is already a `NormalizingFlow`, its
          existing flows are frozen and the new ones are appended.
        - `learning_rate`, `tolerance`, `grad_draws`, `batch_size`: Optimization
          settings stored in `config`; `None` keeps the current value.
        """
        if flowspecs is not None:
            # Initialize NF architecture
            flows = [flowspec.construct(self.vari.dim) for flowspec in flowspecs]

            if isinstance(self.vari, Standard):
                # Construct new normalizing flow on top of the base
                self.vari = NormalizingFlow(
                    base=self.vari,
                    flows=flows,
                    _static=self.vari._static,
                    _unflatten=self.vari._unflatten,
                )
            elif isinstance(self.vari, NormalizingFlow):
                # Freeze current flows in place. `object.__setattr__` bypasses
                # the module's attribute protection to avoid copying the flows.
                for flow in self.vari.flows:
                    object.__setattr__(flow, 'static', True)

                # Append new flows (mutates the existing `flows` list in place)
                self.vari.flows.extend(flows)

        self._update_config(
            learning_rate=learning_rate,
            tolerance=tolerance,
            grad_draws=grad_draws,
            batch_size=batch_size,
        )

    def fit(
        self,
        max_iters: int = 50_000,
        learning_rate: Optional[float] = None,
        tolerance: Optional[float] = None,
        grad_draws: Optional[int] = None,
        batch_size: Optional[int] = None,
        key: PRNGKeyArray = jr.key(0),
    ):
        """
        Optimize the variational approximation.

        # Parameters
        - `max_iters`: Maximum number of optimization iterations.
        - `learning_rate`, `tolerance`, `grad_draws`, `batch_size`: Optional
          overrides, stored in `config` before fitting.
        - `key`: The PRNG key. The default key is created once and reused on
          every call.

        # Raises
        - `NotImplementedError`: if no flows have been configured yet.
        """
        self._update_config(
            learning_rate=learning_rate,
            tolerance=tolerance,
            grad_draws=grad_draws,
            batch_size=batch_size,
        )

        if isinstance(self.vari, Standard):
            # A default flow sequence is not implemented yet. Previously this
            # branch only printed "TODO", silently leaving an untrained base.
            raise NotImplementedError(
                "No flows have been configured; call `configure(flowspecs=...)` "
                "before `fit` (a default flow sequence is not implemented yet)."
            )

        # Optimize variational approximation with user-specified flows
        self.vari = self.vari.fit(
            max_iters,
            self.config["learning_rate"],
            self.config["tolerance"],
            self.config["grad_draws"],
            self.config["batch_size"],
            key
        )

    @staticmethod
    def _draw_in_batches(
        draw_batch: Callable[[int, PRNGKeyArray], Array],
        n_draws: int,
        batch_size: Optional[int],
        key: PRNGKeyArray,
    ) -> Array:
        """
        Collect `n_draws` outputs of `draw_batch`, at most `batch_size` at a time.

        Full batches run under `jax.lax.scan` with one key per batch, split from
        `key`. A final partial batch covers any remainder, using a key folded
        from `key`. The batch axes are merged into one leading draw axis, and
        the result is then squeezed.
        """
        if n_draws < 1:
            raise ValueError(f"`n_draws` must be a positive integer, got {n_draws}.")
        if batch_size is None:
            batch_size = n_draws
        if batch_size < 1:
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")
        batch_size = min(batch_size, n_draws)

        n_batches, remainder = divmod(n_draws, batch_size)

        def step(carry: None, batch_key: PRNGKeyArray) -> Tuple[None, Array]:
            return None, draw_batch(batch_size, batch_key)

        stacked = scan(
            step,
            init=None,
            xs=jr.split(key, n_batches),
            length=n_batches
        )[1]

        # Merge (n_batches, batch_size, ...) into (n_batches * batch_size, ...)
        # so the output layout does not depend on the chosen batch size.
        draws = stacked.reshape(n_batches * batch_size, *stacked.shape[2:])

        if remainder:
            import jax.numpy as jnp

            tail = draw_batch(remainder, jr.fold_in(key, n_batches))
            draws = jnp.concatenate([draws, tail], axis=0)

        return draws.squeeze()

    def sample(
        self,
        node: str,
        n_draws: int,
        batch_size: Optional[int] = None,
        key: PRNGKeyArray = jr.key(0)
    ) -> Array:
        """
        Sample a node from the posterior distribution.

        # Parameters
        - `node`: The name of the node.
        - `n_draws`: The number of draws from the posterior.
        - `batch_size`: The number of draws for the full model ever initialized
          in memory at once. Defaults to `n_draws`.
        - `key`: The PRNG key.

        # Returns
        The constrained node draws stacked along a leading axis of length
        `n_draws`, with size-1 axes squeezed.
        """
        @partial(jax.vmap, in_axes=0)
        def reconstruct_and_subset(draw: Array):
            model = self.vari.reconstruct_model(draw).constrain()[0]
            return getattr(model, node).obj

        def draw_batch(size: int, batch_key: PRNGKeyArray) -> Array:
            return reconstruct_and_subset(self.vari.sample(size, batch_key))

        return self._draw_in_batches(draw_batch, n_draws, batch_size, key)

    def predictive(
        self,
        func: Callable[[M, PRNGKeyArray], Node[Array] | Array],
        n_draws: int,
        batch_size: Optional[int] = None,
        key: PRNGKeyArray = jr.key(0)
    ) -> Array:
        """
        Evaluate a (possibly random) function of the model over posterior draws.

        # Parameters
        - `func`: Called as `func(model, key)` on each constrained posterior
          model, with its own PRNG key. It may return a `Node`, in which case
          its `obj` is used.
        - `n_draws`: The number of draws from the posterior.
        - `batch_size`: The number of draws for the full model ever initialized
          in memory at once. Defaults to `n_draws`.
        - `key`: The PRNG key.

        # Returns
        The outputs of `func` stacked along a leading axis of length `n_draws`,
        with size-1 axes squeezed.
        """
        @partial(jax.vmap, in_axes=(0, 0))
        def reconstruct_and_predict(draw: Array, func_key: PRNGKeyArray) -> Array:
            model = self.vari.reconstruct_model(draw).constrain()[0]

            # Compute predictive
            obj = func(model, func_key)

            # Coerce from Node if needed
            if isinstance(obj, Node):
                obj = obj.obj  # type: ignore

            return obj

        def draw_batch(size: int, batch_key: PRNGKeyArray) -> Array:
            # Use independent subkeys for the posterior draws and for `func`.
            # Reusing one key for both would correlate the two sources of
            # randomness.
            draw_key, func_key = jr.split(batch_key)
            draws = self.vari.sample(size, draw_key)
            return reconstruct_and_predict(draws, jr.split(func_key, size))

        return self._draw_in_batches(draw_batch, n_draws, batch_size, key)
bayinx/vi/__init__.py ADDED
File without changes
@@ -1,25 +1,28 @@
1
- from typing import Any, Dict, Generic, Self, TypeVar
1
+ from typing import Generic, Self
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
5
5
  import jax.random as jr
6
6
  import jax.tree_util as jtu
7
7
  from jax.flatten_util import ravel_pytree
8
- from jaxtyping import Array, Float, Key, Scalar
8
+ from jaxtyping import Array, PRNGKeyArray, Scalar
9
9
 
10
- from bayinx.core import Model, Variational
10
+ from bayinx.core.variational import M, Variational
11
11
  from bayinx.dists import normal
12
12
 
13
- M = TypeVar('M', bound=Model)
13
+
14
14
  class MeanField(Variational, Generic[M]):
15
15
  """
16
16
  A fully factorized Gaussian approximation to a posterior distribution.
17
17
 
18
18
  # Attributes
19
- - `var_params`: The variational parameters for the approximation.
19
+ - `dim`: The dimension of the support.
20
+ - `mean`: The mean of the unconstrained approximation.
21
+ - `log_std` The log-transformed standard deviation of the unconstrained approximation.
20
22
  """
21
23
 
22
- var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
24
+ mean: Array
25
+ log_std: Array
23
26
 
24
27
  def __init__(self, model: M):
25
28
  """
@@ -27,28 +30,31 @@ class MeanField(Variational, Generic[M]):
27
30
 
28
31
  # Parameters
29
32
  - `model`: A probabilistic `Model` object.
33
+ - `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
30
34
  """
31
35
  # Partition model
32
- params, self._constraints = eqx.partition(model, model.filter_spec)
36
+ params, self._static = eqx.partition(model, model.filter_spec)
33
37
 
34
38
  # Flatten params component
35
39
  params, self._unflatten = ravel_pytree(params)
36
40
 
37
41
  # Initialize variational parameters
38
- self.var_params = {
39
- "mean": params,
40
- "log_std": jnp.zeros(params.size, dtype=params.dtype),
41
- }
42
+ self.mean = params
43
+ self.log_std = jnp.full(params.size, 0.0)
42
44
 
43
45
  @property
44
- @eqx.filter_jit
45
46
  def filter_spec(self):
46
47
  # Generate empty specification
47
48
  filter_spec = jtu.tree_map(lambda _: False, self)
48
49
 
49
50
  # Specify variational parameters
50
51
  filter_spec = eqx.tree_at(
51
- lambda mf: mf.var_params,
52
+ lambda mf: mf.mean,
53
+ filter_spec,
54
+ replace=True,
55
+ )
56
+ filter_spec = eqx.tree_at(
57
+ lambda mf: mf.log_std,
52
58
  filter_spec,
53
59
  replace=True,
54
60
  )
@@ -56,12 +62,11 @@ class MeanField(Variational, Generic[M]):
56
62
  return filter_spec
57
63
 
58
64
  @eqx.filter_jit
59
- def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
65
+ def sample(self, n: int, key: PRNGKeyArray = jr.PRNGKey(0)) -> Array:
60
66
  # Sample variational draws
61
67
  draws: Array = (
62
- jr.normal(key=key, shape=(n, self.var_params["mean"].size))
63
- * jnp.exp(self.var_params["log_std"])
64
- + self.var_params["mean"]
68
+ jr.normal(key=key, shape=(n, self.mean.size)) * jnp.exp(self.log_std)
69
+ + self.mean
65
70
  )
66
71
 
67
72
  return draws
@@ -70,23 +75,23 @@ class MeanField(Variational, Generic[M]):
70
75
  def eval(self, draws: Array) -> Array:
71
76
  return normal.logprob(
72
77
  x=draws,
73
- mu=self.var_params["mean"],
74
- sigma=jnp.exp(self.var_params["log_std"]),
78
+ mu=self.mean,
79
+ sigma=jnp.exp(self.log_std),
75
80
  ).sum(axis=1)
76
81
 
77
82
  @eqx.filter_jit
78
- def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
83
+ def elbo(self, n: int, batch_size: int, key: PRNGKeyArray) -> Scalar:
79
84
  dyn, static = eqx.partition(self, self.filter_spec)
80
85
 
81
86
  @eqx.filter_jit
82
- def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
87
+ def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
83
88
  vari = eqx.combine(dyn, static)
84
89
 
85
90
  # Sample draws from variational distribution
86
91
  draws: Array = vari.sample(n, key)
87
92
 
88
93
  # Evaluate posterior density for each draw
89
- posterior_evals: Array = vari.eval_model(draws, data)
94
+ posterior_evals: Array = vari.eval_model(draws)
90
95
 
91
96
  # Evaluate variational density for each draw
92
97
  variational_evals: Array = vari.eval(draws)
@@ -94,23 +99,22 @@ class MeanField(Variational, Generic[M]):
94
99
  # Evaluate ELBO
95
100
  return jnp.mean(posterior_evals - variational_evals)
96
101
 
97
- return elbo(dyn, n, key, data)
102
+ return elbo(dyn, n, key)
98
103
 
99
104
  @eqx.filter_jit
100
- def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
105
+ def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
101
106
  dyn, static = eqx.partition(self, self.filter_spec)
102
107
 
103
- @eqx.filter_grad
104
108
  @eqx.filter_jit
105
- def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
106
- # Combine
109
+ @eqx.filter_grad
110
+ def elbo_grad(dyn: Self, n: int, key: PRNGKeyArray):
107
111
  vari = eqx.combine(dyn, static)
108
112
 
109
113
  # Sample draws from variational distribution
110
114
  draws: Array = vari.sample(n, key)
111
115
 
112
116
  # Evaluate posterior density for each draw
113
- posterior_evals: Array = vari.eval_model(draws, data)
117
+ posterior_evals: Array = vari.eval_model(draws)
114
118
 
115
119
  # Evaluate variational density for each draw
116
120
  variational_evals: Array = vari.eval(draws)
@@ -118,4 +122,4 @@ class MeanField(Variational, Generic[M]):
118
122
  # Evaluate ELBO
119
123
  return jnp.mean(posterior_evals - variational_evals)
120
124
 
121
- return elbo_grad(dyn, n, key, data)
125
+ return elbo_grad(dyn, n, key)