bayinx 0.3.19__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bayinx might be problematic. Click here for more details.

Files changed (48) hide show
  1. bayinx-0.4.0/PKG-INFO +47 -0
  2. bayinx-0.4.0/README.md +33 -0
  3. {bayinx-0.3.19 → bayinx-0.4.0}/pyproject.toml +4 -2
  4. bayinx-0.4.0/src/bayinx/__init__.py +3 -0
  5. bayinx-0.3.19/src/bayinx/constraints/lower.py → bayinx-0.4.0/src/bayinx/constraints.py +9 -13
  6. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/__init__.py +2 -2
  7. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_flow.py +0 -3
  8. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_model.py +41 -12
  9. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_optimization.py +3 -0
  10. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_parameter.py +7 -6
  11. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_variational.py +7 -10
  12. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/posnormal/r.py +6 -1
  13. bayinx-0.4.0/src/bayinx/dists/negbinom3.py +113 -0
  14. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/posnormal.py +38 -1
  15. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/uniform.py +6 -2
  16. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -2
  17. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/meanfield.py +3 -6
  18. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/normalizing_flow.py +4 -5
  19. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/standard.py +2 -1
  20. bayinx-0.4.0/tests/__init__.py +0 -0
  21. {bayinx-0.3.19 → bayinx-0.4.0}/tests/test_variational.py +8 -12
  22. {bayinx-0.3.19 → bayinx-0.4.0}/uv.lock +266 -266
  23. bayinx-0.3.19/PKG-INFO +0 -39
  24. bayinx-0.3.19/README.md +0 -27
  25. bayinx-0.3.19/src/bayinx/__init__.py +0 -3
  26. bayinx-0.3.19/src/bayinx/constraints/__init__.py +0 -3
  27. bayinx-0.3.19/tests/test_predictive.py +0 -45
  28. {bayinx-0.3.19 → bayinx-0.4.0}/.github/workflows/release_and_publish.yml +0 -0
  29. {bayinx-0.3.19 → bayinx-0.4.0}/.gitignore +0 -0
  30. {bayinx-0.3.19 → bayinx-0.4.0}/LICENSE +0 -0
  31. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/core/_constraint.py +0 -0
  32. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/__init__.py +0 -0
  33. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/bernoulli.py +0 -0
  34. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/__init__.py +0 -0
  35. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
  36. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  37. bayinx-0.3.19/src/bayinx/mhx/opt/__init__.py → bayinx-0.4.0/src/bayinx/dists/censored/negbinom3/r.py +0 -0
  38. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
  39. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/gamma2.py +0 -0
  40. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/dists/normal.py +0 -0
  41. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/__init__.py +0 -0
  42. {bayinx-0.3.19/tests → bayinx-0.4.0/src/bayinx/mhx/opt}/__init__.py +0 -0
  43. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/__init__.py +0 -0
  44. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  45. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  46. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/radial.py +1 -1
  47. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  48. {bayinx-0.3.19 → bayinx-0.4.0}/src/bayinx/py.typed +0 -0
bayinx-0.4.0/PKG-INFO ADDED
@@ -0,0 +1,47 @@
1
+ Metadata-Version: 2.4
2
+ Name: bayinx
3
+ Version: 0.4.0
4
+ Summary: Bayesian Inference with JAX
5
+ Author: Todd McCready
6
+ Maintainer: Todd McCready
7
+ License-File: LICENSE
8
+ Requires-Python: >=3.12
9
+ Requires-Dist: equinox>=0.11.12
10
+ Requires-Dist: jax>=0.4.38
11
+ Requires-Dist: jaxtyping>=0.2.36
12
+ Requires-Dist: optax>=0.2.4
13
+ Description-Content-Type: text/markdown
14
+
15
+ # `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
16
+
17
+ The original aim of this project was to build a PPL in Python that is similar in feel to `Stan` or `Nimble` (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
18
+
19
+ Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to `Stan` and parse it into valid Rust code. Additionally, `bayinx` is currently relatively functional (apart from a few things to clean up and some documentation), and it offers enough functionality for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite `disize` in Python with JAX, and `bayinx` makes it easy to handle constraining transformations, filter parameters for gradient calculations, etc.
20
+
21
+ Instead, this project is narrowing its focus to implementing much of `Stan`'s functionality (restricted to continuously parameterized models, point estimation + VI + MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, users will work with `target` directly and return the density, as shown below:
22
+
23
+ ```py
24
+ class NormalDist(Model):
25
+ x: Parameter[Array] = define(shape = (2,))
26
+
27
+ def eval(self, data: Dict[str, Array]):
28
+ # Constrain parameters
29
+ self, target = self.constrain_params() # this does nothing for the current model
30
+
31
+ # Evaluate x ~ Normal(10.0, 1.0)
32
+ target += normal.logprob(self.x(), 10.0, 1.0).sum()
33
+
34
+ return target
35
+ ```
36
+
37
+ I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is trying to do, but that is for the future; versions `0.4.#` will retain the functionality needed for `disize`.
38
+
39
+
40
+ # TODO
41
+ - For optimization and variational methods, offer a way for users to define custom stopping conditions (e.g., stop once a single parameter has converged).
42
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
43
+ - Low-rank affine flow?
44
+ - https://arxiv.org/pdf/1803.05649 implement sylvester flows.
45
+ - Learn how to generate documentation.
46
+ - Figure out how to implement `transform_params` for flows without a performance loss; I'm noticing some weird behaviour when adding constraints.
47
+ - Look into adaptively tuning ADAM hyperparameters for VI.
bayinx-0.4.0/README.md ADDED
@@ -0,0 +1,33 @@
1
+ # `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
2
+
3
+ The original aim of this project was to build a PPL in Python that is similar in feel to `Stan` or `Nimble` (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
4
+
5
+ Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to `Stan` and parse it into valid Rust code. Additionally, `bayinx` is currently relatively functional (apart from a few things to clean up and some documentation), and it offers enough functionality for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite `disize` in Python with JAX, and `bayinx` makes it easy to handle constraining transformations, filter parameters for gradient calculations, etc.
6
+
7
+ Instead, this project is narrowing its focus to implementing much of `Stan`'s functionality (restricted to continuously parameterized models, point estimation + VI + MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, users will work with `target` directly and return the density, as shown below:
8
+
9
+ ```py
10
+ class NormalDist(Model):
11
+ x: Parameter[Array] = define(shape = (2,))
12
+
13
+ def eval(self, data: Dict[str, Array]):
14
+ # Constrain parameters
15
+ self, target = self.constrain_params() # this does nothing for the current model
16
+
17
+ # Evaluate x ~ Normal(10.0, 1.0)
18
+ target += normal.logprob(self.x(), 10.0, 1.0).sum()
19
+
20
+ return target
21
+ ```
22
+
23
+ I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is trying to do, but that is for the future; versions `0.4.#` will retain the functionality needed for `disize`.
24
+
25
+
26
+ # TODO
27
+ - For optimization and variational methods, offer a way for users to define custom stopping conditions (e.g., stop once a single parameter has converged).
28
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
29
+ - Low-rank affine flow?
30
+ - https://arxiv.org/pdf/1803.05649 implement sylvester flows.
31
+ - Learn how to generate documentation.
32
+ - Figure out how to implement `transform_params` for flows without a performance loss; I'm noticing some weird behaviour when adding constraints.
33
+ - Look into adaptively tuning ADAM hyperparameters for VI.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.3.19"
3
+ version = "0.4.0"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -10,6 +10,8 @@ dependencies = [
10
10
  "jaxtyping>=0.2.36",
11
11
  "optax>=0.2.4",
12
12
  ]
13
+ authors = [{ name = "Todd McCready" }]
14
+ maintainers = [{ name = "Todd McCready" }]
13
15
 
14
16
  [build-system]
15
17
  requires = ["hatchling"]
@@ -19,7 +21,7 @@ build-backend = "hatchling.build"
19
21
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
20
22
 
21
23
  [tool.bumpversion]
22
- current_version = "0.3.19"
24
+ current_version = "0.4.0"
23
25
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
24
26
  serialize = ["{major}.{minor}.{patch}"]
25
27
  search = "{current_version}"
@@ -0,0 +1,3 @@
1
+ from bayinx.core import Model, Parameter, define
2
+
3
+ __all__ = ["Model", "Parameter", "define"]
@@ -3,9 +3,10 @@ from typing import Tuple
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
5
5
  import jax.tree as jt
6
- from jaxtyping import PyTree, Scalar, ScalarLike
6
+ from jaxtyping import Scalar, ScalarLike
7
7
 
8
8
  from bayinx.core import Constraint, Parameter
9
+ from bayinx.core._parameter import T
9
10
 
10
11
 
11
12
  class Lower(Constraint):
@@ -18,33 +19,28 @@ class Lower(Constraint):
18
19
  def __init__(self, lb: ScalarLike):
19
20
  self.lb = jnp.array(lb)
20
21
 
21
- @eqx.filter_jit
22
- def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
22
+ def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
23
23
  """
24
24
  Enforces a lower bound on the parameter and adjusts the posterior density.
25
25
 
26
26
  # Parameters
27
- - `x`: The unconstrained `Parameter`.
27
+ - `param`: The unconstrained `Parameter`.
28
28
 
29
29
  # Parameters
30
30
  A tuple containing:
31
31
  - A modified `Parameter` with relevant leaves satisfying the constraint.
32
32
  - A scalar Array representing the log-absolute-Jacobian of the transformation.
33
33
  """
34
- # Extract relevant filter specification
35
- filter_spec = x.filter_spec
36
-
37
34
  # Extract relevant parameters(all Array)
38
- dyn_params, static_params = eqx.partition(x, filter_spec)
35
+ dyn, static = eqx.partition(param, param.filter_spec)
39
36
 
40
37
  # Compute density adjustment
41
- laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
42
- laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
38
+ laj: Scalar = jt.reduce(lambda a, b: a + b, jt.map(jnp.sum, dyn))
43
39
 
44
40
  # Compute transformation
45
- dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
41
+ dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)
46
42
 
47
43
  # Combine into full parameter object
48
- x = eqx.combine(dyn_params, static_params)
44
+ param = eqx.combine(dyn, static)
49
45
 
50
- return x, laj
46
+ return param, laj
@@ -1,6 +1,6 @@
1
1
  from ._constraint import Constraint
2
2
  from ._flow import Flow
3
- from ._model import Model, constrain
3
+ from ._model import Model, define
4
4
  from ._optimization import optimize_model
5
5
  from ._parameter import Parameter
6
6
  from ._variational import Variational
@@ -9,7 +9,7 @@ __all__ = [
9
9
  "Constraint",
10
10
  "Flow",
11
11
  "Model",
12
- "constrain",
12
+ "define",
13
13
  "optimize_model",
14
14
  "Parameter",
15
15
  "Variational",
@@ -37,7 +37,6 @@ class Flow(eqx.Module):
37
37
 
38
38
  # Default filter specification
39
39
  @property
40
- @eqx.filter_jit
41
40
  def filter_spec(self):
42
41
  """
43
42
  Generates a filter specification to subset relevant parameters for the flow.
@@ -54,7 +53,6 @@ class Flow(eqx.Module):
54
53
 
55
54
  return filter_spec
56
55
 
57
- @eqx.filter_jit
58
56
  def constrain_params(self: Self):
59
57
  """
60
58
  Constrain `params` to the appropriate domain.
@@ -69,7 +67,6 @@ class Flow(eqx.Module):
69
67
 
70
68
  return t_params
71
69
 
72
- @eqx.filter_jit
73
70
  def transform_params(self: Self) -> Dict[str, Array]:
74
71
  """
75
72
  Apply a custom transformation to `params` if needed.
@@ -1,19 +1,34 @@
1
1
  from abc import abstractmethod
2
2
  from dataclasses import field, fields
3
- from typing import Any, Self, Tuple
3
+ from typing import Any, Dict, Optional, Self, Tuple
4
4
 
5
5
  import equinox as eqx
6
6
  import jax.numpy as jnp
7
7
  import jax.tree as jt
8
- from jaxtyping import Scalar
8
+ from jaxtyping import PyTree, Scalar
9
9
 
10
10
  from ._constraint import Constraint
11
11
  from ._parameter import Parameter
12
12
 
13
13
 
14
- def constrain(constraint: Constraint):
15
- """Defines constraint metadata."""
16
- return field(metadata={"constraint": constraint})
14
+ def define(
15
+ shape: Optional[Tuple[int, ...]] = None,
16
+ init: Optional[PyTree] = None,
17
+ constraint: Optional[Constraint] = None
18
+ ):
19
+ """Define a parameter."""
20
+ metadata: Dict = {}
21
+ if constraint is not None:
22
+ metadata["constraint"] = constraint
23
+
24
+ if isinstance(shape, Tuple):
25
+ metadata["shape"] = shape
26
+ elif isinstance(init, PyTree):
27
+ metadata["init"] = init
28
+ else:
29
+ raise TypeError("Neither 'shape' nor 'init' were given as proper arguments.")
30
+
31
+ return field(metadata = metadata)
17
32
 
18
33
 
19
34
  class Model(eqx.Module):
@@ -22,16 +37,32 @@ class Model(eqx.Module):
22
37
 
23
38
  Annotate parameter attributes with `Parameter`.
24
39
 
25
- Include constraints by setting them equal to `constrain(Constraint)`.
40
+ Include constraints by setting them equal to `define(Constraint)`.
26
41
  """
27
42
 
43
+ def __new__(cls, *args, **kwargs):
44
+ obj = super().__new__(cls)
45
+
46
+ # Auto-initialize parameters based on `define` metadata
47
+ for f in fields(cls):
48
+ if "shape" in f.metadata:
49
+ # Construct jax Array with correct dimensions
50
+ setattr(obj, f.name, Parameter(jnp.zeros(f.metadata["shape"])))
51
+ elif "init" in f.metadata:
52
+ # Slot in given 'init' object
53
+ setattr(obj, f.name, Parameter(f.metadata["init"]))
54
+
55
+ return obj
56
+
57
+ def __init__(self):
58
+ return self
59
+
28
60
  @abstractmethod
29
61
  def eval(self, data: Any) -> Scalar:
30
62
  pass
31
63
 
32
64
  # Default filter specification
33
65
  @property
34
- @eqx.filter_jit
35
66
  def filter_spec(self) -> Self:
36
67
  """
37
68
  Generates a filter specification to subset relevant parameters for the model.
@@ -49,12 +80,11 @@ class Model(eqx.Module):
49
80
  filter_spec = eqx.tree_at(
50
81
  lambda model: getattr(model, f.name),
51
82
  filter_spec,
52
- replace=attr.filter_spec,
83
+ replace=attr.filter_spec
53
84
  )
54
85
 
55
86
  return filter_spec
56
87
 
57
- @eqx.filter_jit
58
88
  def constrain_params(self) -> Tuple[Self, Scalar]:
59
89
  """
60
90
  Constrain parameters to the appropriate domain.
@@ -70,14 +100,14 @@ class Model(eqx.Module):
70
100
  attr = getattr(self, f.name)
71
101
 
72
102
  # Check if constrained parameter
73
- if isinstance(attr, Parameter) and "constraint" in f.metadata:
103
+ if isinstance(attr, Parameter) and ("constraint" in f.metadata):
74
104
  param = attr
75
105
  constraint = f.metadata["constraint"]
76
106
 
77
107
  # Apply constraint
78
108
  param, laj = constraint.constrain(param)
79
109
 
80
- # Update parameters for constrained model
110
+ # Update parameters for constrained model at same node
81
111
  constrained = eqx.tree_at(
82
112
  lambda model: getattr(model, f.name), constrained, replace=param
83
113
  )
@@ -87,7 +117,6 @@ class Model(eqx.Module):
87
117
 
88
118
  return constrained, target
89
119
 
90
- @eqx.filter_jit
91
120
  def transform_params(self) -> Tuple[Self, Scalar]:
92
121
  """
93
122
  Apply a custom transformation to parameters if needed(defaults to constrained parameters).
@@ -10,6 +10,8 @@ from optax import GradientTransformation, OptState, Schedule
10
10
  from ._model import Model
11
11
 
12
12
  M = TypeVar("M", bound=Model)
13
+
14
+
13
15
  @eqx.filter_jit
14
16
  def optimize_model(
15
17
  model: M,
@@ -39,6 +41,7 @@ def optimize_model(
39
41
 
40
42
  # Evaluate posterior
41
43
  return model.eval(data)
44
+
42
45
  eval_grad: Callable[[M], M] = eqx.filter_jit(eqx.filter_grad(eval))
43
46
 
44
47
  # Construct scheduler
@@ -5,6 +5,8 @@ import jax.tree as jt
5
5
  from jaxtyping import PyTree
6
6
 
7
7
  T = TypeVar("T", bound=PyTree)
8
+
9
+
8
10
  class Parameter(eqx.Module, Generic[T]):
9
11
  """
10
12
  A container for a parameter of a `Model`.
@@ -26,19 +28,18 @@ class Parameter(eqx.Module, Generic[T]):
26
28
 
27
29
  # Default filter specification
28
30
  @property
29
- @eqx.filter_jit
30
31
  def filter_spec(self) -> Self:
31
32
  """
32
- Generates a filter specification to filter out static parameters.
33
+ Generates a filter specification to filter for dynamic parameters.
33
34
  """
34
35
  # Generate empty specification
35
- filter_spec = jt.map(lambda _: False, self)
36
+ filter_spec: Self = jt.map(lambda _: False, self)
36
37
 
37
- # Specify Array leaves
38
+ # Specify Array-like leaves
38
39
  filter_spec = eqx.tree_at(
39
- lambda params: params.vals,
40
+ lambda param: param.vals,
40
41
  filter_spec,
41
- replace=jt.map(eqx.is_array_like, self.vals),
42
+ replace=jt.map(eqx.is_inexact_array_like, self.vals),
42
43
  )
43
44
 
44
45
  return filter_spec
@@ -22,11 +22,11 @@ class Variational(eqx.Module, Generic[M]):
22
22
 
23
23
  # Attributes
24
24
  - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
25
- - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
25
+ - `_static`: The static component of a partitioned `Model` used to initialize the `Variational` object.
26
26
  """
27
27
 
28
28
  _unflatten: Callable[[Array], M]
29
- _constraints: M
29
+ _static: M
30
30
 
31
31
  @abstractmethod
32
32
  def filter_spec(self):
@@ -69,7 +69,7 @@ class Variational(eqx.Module, Generic[M]):
69
69
  model: M = self._unflatten(draw)
70
70
 
71
71
  # Combine with constraints
72
- model: M = eqx.combine(model, self._constraints)
72
+ model: M = eqx.combine(model, self._static)
73
73
 
74
74
  return model
75
75
 
@@ -89,7 +89,6 @@ class Variational(eqx.Module, Generic[M]):
89
89
  # Evaluate posterior density
90
90
  return model.eval(data)
91
91
 
92
- # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
93
92
  @eqx.filter_jit
94
93
  def fit(
95
94
  self,
@@ -116,11 +115,9 @@ class Variational(eqx.Module, Generic[M]):
116
115
  dyn, static = eqx.partition(self, self.filter_spec)
117
116
 
118
117
  # Construct scheduler
119
- schedule: Schedule = opx.warmup_cosine_decay_schedule(
120
- init_value=1e-16,
121
- peak_value=learning_rate,
122
- warmup_steps=int(max_iters / 10),
123
- decay_steps=max_iters - int(max_iters / 10),
118
+ schedule: Schedule = opx.cosine_decay_schedule(
119
+ init_value=learning_rate,
120
+ decay_steps=max_iters,
124
121
  )
125
122
 
126
123
  # Initialize optimizer
@@ -175,7 +172,7 @@ class Variational(eqx.Module, Generic[M]):
175
172
  return eqx.combine(dyn, static)
176
173
 
177
174
  @eqx.filter_jit
178
- def posterior_predictive(
175
+ def _posterior_predictive(
179
176
  self,
180
177
  func: Callable[[M, Any], Array],
181
178
  n: int,
@@ -111,6 +111,11 @@ def sample(
111
111
 
112
112
  # Construct draws
113
113
  draws = jr.uniform(key, shape)
114
- draws = mu + sigma * ndtri(normal.cdf(-mu/sigma, 0.0, 1.0) + draws * normal.cdf(mu/sigma, 0.0, 1.0))
114
+ draws = mu + sigma * ndtri(
115
+ normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
116
+ )
117
+
118
+ # Censor draws
119
+ draws.at[censor <= draws].set(censor)
115
120
 
116
121
  return draws
@@ -0,0 +1,113 @@
1
+ import jax.numpy as jnp
2
+ from jax.scipy.special import gammaln
3
+ from jaxtyping import Array, ArrayLike, Float, UInt
4
+
5
+ __PI = 3.141592653589793
6
+
7
+
8
+ def __binom(x, y):
9
+ """
10
+ Helper function for the Binomial coefficient.
11
+ """
12
+ return jnp.exp(gammaln(x + 1) - gammaln(y + 1) - gammaln(x - y + 1))
13
+
14
+
15
+ def prob(
16
+ x: UInt[ArrayLike, "..."],
17
+ mu: Float[ArrayLike, "..."],
18
+ phi: Float[ArrayLike, "..."],
19
+ ) -> Float[Array, "..."]:
20
+ """
21
+ The probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negatvie Binomial distribution.
22
+
23
+ # Parameters
24
+ - `x`: Where to evaluate the PMF.
25
+ - `mu`: The mean.
26
+ - `phi`: The inverse overdispersion.
27
+
28
+ # Returns
29
+ The PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
30
+ """
31
+ # Cast to Array
32
+ x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)
33
+
34
+ return jnp.exp(logprob(x, mu, phi))
35
+
36
+
37
+ def logprob(
38
+ x: UInt[ArrayLike, "..."],
39
+ mu: Float[ArrayLike, "..."],
40
+ phi: Float[ArrayLike, "..."],
41
+ ) -> Float[Array, "..."]:
42
+ """
43
+ The log-transformed probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negatvie Binomial distribution.
44
+
45
+ # Parameters
46
+ - `x`: Where to evaluate the log PMF.
47
+ - `mu`: The mean.
48
+ - `phi`: The inverse overdispersion.
49
+
50
+ # Returns
51
+ The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
52
+ """
53
+ # Cast to Array
54
+ x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)
55
+
56
+ # Evaluate log PMF
57
+ evals: Array = jnp.where(
58
+ x < 0,
59
+ -jnp.inf,
60
+ (
61
+ gammaln(x + phi)
62
+ - gammaln(x + 1)
63
+ - gammaln(phi)
64
+ + x * (jnp.log(mu) - jnp.log(mu + phi))
65
+ + phi * (jnp.log(phi) - jnp.log(mu + phi))
66
+ ),
67
+ )
68
+
69
+ return evals
70
+
71
+
72
+ def cdf(
73
+ x: Float[ArrayLike, "..."],
74
+ mu: Float[ArrayLike, "..."],
75
+ sigma: Float[ArrayLike, "..."],
76
+ ) -> Float[Array, "..."]:
77
+ # Cast to Array
78
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
79
+
80
+ return jnp.array(1.0)
81
+
82
+
83
+ def logcdf(
84
+ x: Float[ArrayLike, "..."],
85
+ mu: Float[ArrayLike, "..."],
86
+ sigma: Float[ArrayLike, "..."],
87
+ ) -> Float[Array, "..."]:
88
+ # Cast to Array
89
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
90
+
91
+ return jnp.array(1.0)
92
+
93
+
94
+ def ccdf(
95
+ x: Float[ArrayLike, "..."],
96
+ mu: Float[ArrayLike, "..."],
97
+ sigma: Float[ArrayLike, "..."],
98
+ ) -> Float[Array, "..."]:
99
+ # Cast to Array
100
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
101
+
102
+ return jnp.array(1.0)
103
+
104
+
105
+ def logccdf(
106
+ x: Float[ArrayLike, "..."],
107
+ mu: Float[ArrayLike, "..."],
108
+ sigma: Float[ArrayLike, "..."],
109
+ ) -> Float[Array, "..."]:
110
+ # Cast to Array
111
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
112
+
113
+ return jnp.array(1.0)
@@ -1,5 +1,7 @@
1
1
  import jax.numpy as jnp
2
- from jaxtyping import Array, ArrayLike, Float
2
+ import jax.random as jr
3
+ from jax.scipy.special import ndtri
4
+ from jaxtyping import Array, ArrayLike, Float, Key
3
5
 
4
6
  from bayinx.dists import normal
5
7
 
@@ -251,3 +253,38 @@ def logccdf(
251
253
  evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
252
254
 
253
255
  return evals
256
+
257
+
258
+ def sample(
259
+ n: int,
260
+ mu: Float[ArrayLike, "..."],
261
+ sigma: Float[ArrayLike, "..."],
262
+ key: Key = jr.PRNGKey(0),
263
+ ) -> Float[Array, "..."]:
264
+ """
265
+ Sample from a positive Normal distribution.
266
+
267
+ # Parameters
268
+ - `n`: Number of draws to sample per-parameter.
269
+ - `mu`: The mean.
270
+ - `sigma`: The standard deviation.
271
+
272
+ # Returns
273
+ Draws from a positive Normal distribution. The output will have the shape of (n,) + the broadcasted shapes of `mu` and `sigma`.
274
+ """
275
+ # Cast to Array
276
+ mu, sigma = (
277
+ jnp.asarray(mu),
278
+ jnp.asarray(sigma),
279
+ )
280
+
281
+ # Derive shape
282
+ shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape)
283
+
284
+ # Construct draws
285
+ draws = jr.uniform(key, shape)
286
+ draws = mu + sigma * ndtri(
287
+ normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
288
+ )
289
+
290
+ return draws
@@ -83,8 +83,12 @@ def ulogprob(
83
83
 
84
84
  return jnp.zeros(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
85
85
 
86
+
86
87
  def sample(
87
- n: int, lb: Float[ArrayLike, "..."], ub: Float[ArrayLike, "..."], key: Key = jr.PRNGKey(0),
88
+ n: int,
89
+ lb: Float[ArrayLike, "..."],
90
+ ub: Float[ArrayLike, "..."],
91
+ key: Key = jr.PRNGKey(0),
88
92
  ) -> Float[Array, "..."]:
89
93
  """
90
94
  Sample from a Uniform distribution.
@@ -104,6 +108,6 @@ def sample(
104
108
  shape = (n,) + jnp.broadcast_shapes(lb.shape, ub.shape)
105
109
 
106
110
  # Construct draws
107
- draws = jr.uniform(key, shape, minval = lb, maxval = ub)
111
+ draws = jr.uniform(key, shape, minval=lb, maxval=ub)
108
112
 
109
113
  return draws
@@ -33,8 +33,6 @@ class FullAffine(Flow):
33
33
  if dim == 1:
34
34
  self.constraints = {}
35
35
  else:
36
-
37
- @eqx.filter_jit
38
36
  def constrain_scale(scale: Array):
39
37
  # Extract diagonal and apply exponential
40
38
  diag: Array = jnp.exp(jnp.diag(scale))
@@ -34,7 +34,7 @@ class MeanField(Variational, Generic[M]):
34
34
  - `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
35
35
  """
36
36
  # Partition model
37
- params, self._constraints = eqx.partition(model, model.filter_spec)
37
+ params, self._static = eqx.partition(model, model.filter_spec)
38
38
 
39
39
  # Flatten params component
40
40
  params, self._unflatten = ravel_pytree(params)
@@ -44,7 +44,6 @@ class MeanField(Variational, Generic[M]):
44
44
  self.log_std = jnp.full(params.size, init_log_std, params.dtype)
45
45
 
46
46
  @property
47
- @eqx.filter_jit
48
47
  def filter_spec(self):
49
48
  # Generate empty specification
50
49
  filter_spec = jtu.tree_map(lambda _: False, self)
@@ -67,8 +66,7 @@ class MeanField(Variational, Generic[M]):
67
66
  def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
68
67
  # Sample variational draws
69
68
  draws: Array = (
70
- jr.normal(key=key, shape=(n, self.mean.size))
71
- * jnp.exp(self.log_std)
69
+ jr.normal(key=key, shape=(n, self.mean.size)) * jnp.exp(self.log_std)
72
70
  + self.mean
73
71
  )
74
72
 
@@ -108,10 +106,9 @@ class MeanField(Variational, Generic[M]):
108
106
  def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
109
107
  dyn, static = eqx.partition(self, self.filter_spec)
110
108
 
111
- @eqx.filter_grad
112
109
  @eqx.filter_jit
110
+ @eqx.filter_grad
113
111
  def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
114
- # Combine
115
112
  vari = eqx.combine(dyn, static)
116
113
 
117
114
  # Sample draws from variational distribution