bayinx 0.3.19__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (49)
  1. bayinx-0.4.1/PKG-INFO +47 -0
  2. bayinx-0.4.1/README.md +33 -0
  3. {bayinx-0.3.19 → bayinx-0.4.1}/pyproject.toml +4 -2
  4. bayinx-0.4.1/src/bayinx/__init__.py +3 -0
  5. bayinx-0.4.1/src/bayinx/constraints.py +135 -0
  6. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/__init__.py +2 -2
  7. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/_flow.py +0 -3
  8. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/_model.py +41 -12
  9. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/_optimization.py +3 -0
  10. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/_parameter.py +7 -6
  11. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/_variational.py +7 -10
  12. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/censored/posnormal/r.py +6 -1
  13. bayinx-0.4.1/src/bayinx/dists/negbinom3.py +113 -0
  14. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/posnormal.py +38 -1
  15. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/uniform.py +6 -2
  16. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -2
  17. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/meanfield.py +3 -6
  18. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/normalizing_flow.py +4 -5
  19. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/standard.py +2 -1
  20. bayinx-0.4.1/tests/__init__.py +0 -0
  21. {bayinx-0.3.19 → bayinx-0.4.1}/tests/test_variational.py +8 -12
  22. {bayinx-0.3.19 → bayinx-0.4.1}/uv.lock +266 -266
  23. bayinx-0.3.19/PKG-INFO +0 -39
  24. bayinx-0.3.19/README.md +0 -27
  25. bayinx-0.3.19/src/bayinx/__init__.py +0 -3
  26. bayinx-0.3.19/src/bayinx/constraints/__init__.py +0 -3
  27. bayinx-0.3.19/src/bayinx/constraints/lower.py +0 -50
  28. bayinx-0.3.19/tests/test_predictive.py +0 -45
  29. {bayinx-0.3.19 → bayinx-0.4.1}/.github/workflows/release_and_publish.yml +0 -0
  30. {bayinx-0.3.19 → bayinx-0.4.1}/.gitignore +0 -0
  31. {bayinx-0.3.19 → bayinx-0.4.1}/LICENSE +0 -0
  32. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/core/_constraint.py +0 -0
  33. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/__init__.py +0 -0
  34. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/bernoulli.py +0 -0
  35. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/censored/__init__.py +0 -0
  36. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
  37. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  38. bayinx-0.3.19/src/bayinx/mhx/opt/__init__.py → bayinx-0.4.1/src/bayinx/dists/censored/negbinom3/r.py +0 -0
  39. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
  40. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/gamma2.py +0 -0
  41. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/dists/normal.py +0 -0
  42. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/__init__.py +0 -0
  43. {bayinx-0.3.19/tests → bayinx-0.4.1/src/bayinx/mhx/opt}/__init__.py +0 -0
  44. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/__init__.py +0 -0
  45. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  46. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  47. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/radial.py +1 -1
  48. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  49. {bayinx-0.3.19 → bayinx-0.4.1}/src/bayinx/py.typed +0 -0
bayinx-0.4.1/PKG-INFO ADDED
@@ -0,0 +1,47 @@
1
+ Metadata-Version: 2.4
2
+ Name: bayinx
3
+ Version: 0.4.1
4
+ Summary: Bayesian Inference with JAX
5
+ Author: Todd McCready
6
+ Maintainer: Todd McCready
7
+ License-File: LICENSE
8
+ Requires-Python: >=3.12
9
+ Requires-Dist: equinox>=0.11.12
10
+ Requires-Dist: jax>=0.4.38
11
+ Requires-Dist: jaxtyping>=0.2.36
12
+ Requires-Dist: optax>=0.2.4
13
+ Description-Content-Type: text/markdown
14
+
15
+ # Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
16
+
17
+ The original aim of this project was to build a PPL in Python that is similar in feel to Stan or Nimble (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
18
+
19
+ Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to Stan and parse it into valid Rust code. Additionally, the current state of bayinx is relatively functional (plus or minus a few things to clean up and document), and it offers enough for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filtering parameters for gradient calculations, etc.
20
+
21
+ Instead, this project is narrowing its focus to implementing much of Stan's functionality (restricted to continuously parameterized models; point estimation, VI, MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, users will work with `target` directly and return the density, as below:
22
+
23
+ ```py
24
+ class NormalDist(Model):
25
+ x: Parameter[Array] = define(shape = (2,))
26
+
27
+ def eval(self, data: Dict[str, Array]):
28
+ # Constrain parameters
29
+ self, target = self.constrain_params() # this does nothing for the current model
30
+
31
+ # Evaluate x ~ Normal(10.0, 1.0)
32
+ target += normal.logprob(self.x(), 10.0, 1.0).sum()
33
+
34
+ return target
35
+ ```
36
+
37
+ I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is aiming for, but that is for the future; versions `0.4.#` will retain the functionality needed for disize.
38
+
39
+
40
+ # TODO
41
+ - For optimization and variational methods, offer a way for users to define custom stopping conditions (perhaps stop once a single parameter has converged, etc.).
42
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
43
+ - Low-rank affine flow?
44
+ - Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
45
+ - Learn how to generate documentation.
46
+ - Figure out how to implement `transform_pars` for flows without a performance loss; noticing some weird behaviour when adding constraints.
47
+ - Look into adaptively tuning ADAM hyperparameters for VI.
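For intuition, the `target` accumulated in `eval` above is just an unnormalized log posterior density. A minimal standalone JAX sketch of the same quantity for the `NormalDist` example (purely illustrative, not using bayinx's `Model` machinery):

```py
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def target(x: jnp.ndarray) -> jnp.ndarray:
    # x ~ Normal(10.0, 1.0), summed over both components
    return norm.logpdf(x, loc=10.0, scale=1.0).sum()

x = jnp.zeros(2)
print(target(x))            # unnormalized log density at x = [0, 0]
print(jax.grad(target)(x))  # the gradient that optimization / VI / MCMC consume
```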
bayinx-0.4.1/README.md ADDED
@@ -0,0 +1,33 @@
1
+ # Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
2
+
3
+ The original aim of this project was to build a PPL in Python that is similar in feel to Stan or Nimble (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
4
+
5
+ Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to Stan and parse it into valid Rust code. Additionally, the current state of bayinx is relatively functional (plus or minus a few things to clean up and document), and it offers enough for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filtering parameters for gradient calculations, etc.
6
+
7
+ Instead, this project is narrowing its focus to implementing much of Stan's functionality (restricted to continuously parameterized models; point estimation, VI, MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, users will work with `target` directly and return the density, as below:
8
+
9
+ ```py
10
+ class NormalDist(Model):
11
+ x: Parameter[Array] = define(shape = (2,))
12
+
13
+ def eval(self, data: Dict[str, Array]):
14
+ # Constrain parameters
15
+ self, target = self.constrain_params() # this does nothing for the current model
16
+
17
+ # Evaluate x ~ Normal(10.0, 1.0)
18
+ target += normal.logprob(self.x(), 10.0, 1.0).sum()
19
+
20
+ return target
21
+ ```
22
+
23
+ I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is aiming for, but that is for the future; versions `0.4.#` will retain the functionality needed for disize.
24
+
25
+
26
+ # TODO
27
+ - For optimization and variational methods, offer a way for users to define custom stopping conditions (perhaps stop once a single parameter has converged, etc.).
28
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
29
+ - Low-rank affine flow?
30
+ - Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
31
+ - Learn how to generate documentation.
32
+ - Figure out how to implement `transform_pars` for flows without a performance loss; noticing some weird behaviour when adding constraints.
33
+ - Look into adaptively tuning ADAM hyperparameters for VI.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.3.19"
3
+ version = "0.4.1"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -10,6 +10,8 @@ dependencies = [
10
10
  "jaxtyping>=0.2.36",
11
11
  "optax>=0.2.4",
12
12
  ]
13
+ authors = [{ name = "Todd McCready" }]
14
+ maintainers = [{ name = "Todd McCready" }]
13
15
 
14
16
  [build-system]
15
17
  requires = ["hatchling"]
@@ -19,7 +21,7 @@ build-backend = "hatchling.build"
19
21
  addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
20
22
 
21
23
  [tool.bumpversion]
22
- current_version = "0.3.19"
24
+ current_version = "0.4.1"
23
25
  parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
24
26
  serialize = ["{major}.{minor}.{patch}"]
25
27
  search = "{current_version}"
@@ -0,0 +1,3 @@
1
+ from bayinx.core import Model, Parameter, define
2
+
3
+ __all__ = ["Model", "Parameter", "define"]
@@ -0,0 +1,135 @@
1
+ from typing import Tuple
2
+
3
+ import equinox as eqx
4
+ import jax.nn as jnn
5
+ import jax.numpy as jnp
6
+ import jax.tree as jt
7
+ from jaxtyping import Array, PyTree, Scalar, ScalarLike
8
+
9
+ from bayinx.core import Constraint, Parameter
10
+ from bayinx.core._parameter import T
11
+
12
+
13
+ class Lower(Constraint):
14
+ """
15
+ Enforces a lower bound on the parameter.
16
+ """
17
+
18
+ lb: Scalar
19
+
20
+ def __init__(self, lb: ScalarLike):
21
+ # assert greater than 1
22
+ self.lb = jnp.asarray(lb)
23
+
24
+ def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
25
+ """
26
+ Enforces a lower bound on the parameter and adjusts the posterior density.
27
+
28
+ # Parameters
29
+ - `param`: The unconstrained `Parameter`.
30
+
31
+ # Returns
32
+ A tuple containing:
33
+ - A modified `Parameter` with relevant leaves satisfying the constraint.
34
+ - A scalar Array representing the log-absolute-Jacobian of the transformation.
35
+ """
36
+ # Extract relevant parameters(all inexact Arrays)
37
+ dyn, static = eqx.partition(param, param.filter_spec)
38
+
39
+ # Compute Jacobian adjustment
40
+ total_laj: Scalar = jt.reduce(lambda a, b: a + b, jt.map(jnp.sum, dyn))
41
+
42
+ # Compute transformation
43
+ dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)
44
+
45
+ # Combine into full parameter object
46
+ param = eqx.combine(dyn, static)
47
+
48
+ return param, total_laj
49
+
50
+
51
+ class LogSimplex(Constraint):
52
+ """
53
+ Enforces a log-transformed simplex constraint on the parameter.
54
+
55
+ # Attributes
56
+ - `sum`: The total sum of the parameter.
57
+ """
58
+
59
+ sum: Scalar
60
+
61
+ def __init__(self, sum_val: ScalarLike = 1.0):
62
+ """
63
+ # Parameters
64
+ - `sum_val`: The target sum for the exponentiated simplex. Defaults to 1.0.
65
+ """
66
+ self.sum = jnp.asarray(sum_val)
67
+
68
+ def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
69
+ """
70
+ Enforces a log-transformed simplex constraint on the parameter and adjusts the posterior density.
71
+
72
+ # Parameters
73
+ - `param`: The unconstrained `Parameter`.
74
+
75
+ # Returns
76
+ A tuple containing:
77
+ - A modified `Parameter` with relevant leaves satisfying the constraint.
78
+ - A scalar Array representing the log-absolute-Jacobian of the transformation.
79
+ """
80
+ # Partition the parameter into dynamic (to be transformed) and static parts
81
+ dyn, static = eqx.partition(param, param.filter_spec)
82
+
83
+ # Map transformation leaf-wise
84
+ transformed = jt.map(self._transform_leaf, dyn) ## filter spec handles subsetting arrays, is_leaf unnecessary
85
+
86
+ # Extract constrained parameters and Jacobian adjustments
87
+ dyn_constrained: PyTree = jt.map(lambda x: x[0], transformed)
88
+ lajs: PyTree = jt.map(lambda x: x[1], transformed)
89
+
90
+ # Sum to get total Jacobian adjustment
91
+ total_laj = jt.reduce(lambda a, b: a + b, lajs)
92
+
93
+ # Recombine the transformed dynamic parts with the static parts
94
+ param = eqx.combine(dyn_constrained, static)
95
+
96
+ return param, total_laj
97
+
98
+ def _transform_leaf(self, x: Array) -> Tuple[Array, Scalar]:
99
+ """
100
+ Internal function that applies a log-transformed simplex constraint on a single array.
101
+ """
102
+ laj: Scalar = jnp.array(0.0)
103
+
104
+ # Save output shape
105
+ output_shape: tuple[int, ...] = x.shape
106
+
107
+ if x.size == 1:
108
+ return(jnp.full(output_shape, jnp.log(self.sum)), laj)
109
+ else:
110
+ # Flatten x
111
+ x = x.flatten()
112
+
113
+ # Subset first K - 1 elements
114
+ x = x[:-1]
115
+
116
+ # Compute shifted cumulative sum
117
+ zeta: Array = jnp.concat([jnp.zeros(1), x.cumsum()[:-1]])
118
+
119
+ # Compute intermediate proportions vector
120
+ eta: Array = jnn.sigmoid(x - zeta)
121
+
122
+ # Compute Jacobian adjustment
123
+ laj += jnp.sum(jnp.log(eta) + jnp.log(1 - eta)) # TODO: check for correctness
124
+
125
+ # Compute log-transformed simplex weights
126
+ w: Array = jnp.log(eta) + jnp.concatenate([jnp.array([0.0]), jnp.log(jnp.cumprod((1-eta)[:-1]))])
127
+ w = jnp.concatenate([w, jnp.log(jnp.prod(1 - eta, keepdims=True))])
128
+
129
+ # Scale unit simplex on log-scale
130
+ w = w + jnp.log(self.sum)
131
+
132
+ # Reshape for output
133
+ w = w.reshape(output_shape)
134
+
135
+ return (w, laj)
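A note on the `Lower` transform above: mapping an unconstrained value v to exp(v) + lb has derivative exp(v), so the log-absolute-Jacobian per element is simply v, which is why the adjustment reduces to the sum of the unconstrained leaves. A small standalone check in plain JAX (not using the `Parameter` machinery):

```py
import jax
import jax.numpy as jnp

lb = 2.0
v = jnp.array([-1.0, 0.5, 3.0])  # unconstrained values

constrained = jnp.exp(v) + lb  # every entry now exceeds lb
laj = v.sum()                  # sum of log|d/dv (exp(v) + lb)| = sum(v)

# Cross-check the Jacobian adjustment with autodiff
autodiff_laj = jnp.log(jax.vmap(jax.grad(lambda t: jnp.exp(t) + lb))(v)).sum()
print(constrained, laj, autodiff_laj)  # laj and autodiff_laj agree
```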
@@ -1,6 +1,6 @@
1
1
  from ._constraint import Constraint
2
2
  from ._flow import Flow
3
- from ._model import Model, constrain
3
+ from ._model import Model, define
4
4
  from ._optimization import optimize_model
5
5
  from ._parameter import Parameter
6
6
  from ._variational import Variational
@@ -9,7 +9,7 @@ __all__ = [
9
9
  "Constraint",
10
10
  "Flow",
11
11
  "Model",
12
- "constrain",
12
+ "define",
13
13
  "optimize_model",
14
14
  "Parameter",
15
15
  "Variational",
@@ -37,7 +37,6 @@ class Flow(eqx.Module):
37
37
 
38
38
  # Default filter specification
39
39
  @property
40
- @eqx.filter_jit
41
40
  def filter_spec(self):
42
41
  """
43
42
  Generates a filter specification to subset relevant parameters for the flow.
@@ -54,7 +53,6 @@ class Flow(eqx.Module):
54
53
 
55
54
  return filter_spec
56
55
 
57
- @eqx.filter_jit
58
56
  def constrain_params(self: Self):
59
57
  """
60
58
  Constrain `params` to the appropriate domain.
@@ -69,7 +67,6 @@ class Flow(eqx.Module):
69
67
 
70
68
  return t_params
71
69
 
72
- @eqx.filter_jit
73
70
  def transform_params(self: Self) -> Dict[str, Array]:
74
71
  """
75
72
  Apply a custom transformation to `params` if needed.
@@ -1,19 +1,34 @@
1
1
  from abc import abstractmethod
2
2
  from dataclasses import field, fields
3
- from typing import Any, Self, Tuple
3
+ from typing import Any, Dict, Optional, Self, Tuple
4
4
 
5
5
  import equinox as eqx
6
6
  import jax.numpy as jnp
7
7
  import jax.tree as jt
8
- from jaxtyping import Scalar
8
+ from jaxtyping import PyTree, Scalar
9
9
 
10
10
  from ._constraint import Constraint
11
11
  from ._parameter import Parameter
12
12
 
13
13
 
14
- def constrain(constraint: Constraint):
15
- """Defines constraint metadata."""
16
- return field(metadata={"constraint": constraint})
14
+ def define(
15
+ shape: Optional[Tuple[int, ...]] = None,
16
+ init: Optional[PyTree] = None,
17
+ constraint: Optional[Constraint] = None
18
+ ):
19
+ """Define a parameter."""
20
+ metadata: Dict = {}
21
+ if constraint is not None:
22
+ metadata["constraint"] = constraint
23
+
24
+ if isinstance(shape, Tuple):
25
+ metadata["shape"] = shape
26
+ elif isinstance(init, PyTree):
27
+ metadata["init"] = init
28
+ else:
29
+ raise TypeError("Neither 'shape' nor 'init' were given as proper arguments.")
30
+
31
+ return field(metadata = metadata)
17
32
 
18
33
 
19
34
  class Model(eqx.Module):
@@ -22,16 +37,32 @@ class Model(eqx.Module):
22
37
 
23
38
  Annotate parameter attributes with `Parameter`.
24
39
 
25
- Include constraints by setting them equal to `constrain(Constraint)`.
40
+ Include constraints by setting them equal to `define(Constraint)`.
26
41
  """
27
42
 
43
+ def __new__(cls, *args, **kwargs):
44
+ obj = super().__new__(cls)
45
+
46
+ # Auto-initialize parameters based on `define` metadata
47
+ for f in fields(cls):
48
+ if "shape" in f.metadata:
49
+ # Construct jax Array with correct dimensions
50
+ setattr(obj, f.name, Parameter(jnp.zeros(f.metadata["shape"])))
51
+ elif "init" in f.metadata:
52
+ # Slot in given 'init' object
53
+ setattr(obj, f.name, Parameter(f.metadata["init"]))
54
+
55
+ return obj
56
+
57
+ def __init__(self):
58
+ return self
59
+
28
60
  @abstractmethod
29
61
  def eval(self, data: Any) -> Scalar:
30
62
  pass
31
63
 
32
64
  # Default filter specification
33
65
  @property
34
- @eqx.filter_jit
35
66
  def filter_spec(self) -> Self:
36
67
  """
37
68
  Generates a filter specification to subset relevant parameters for the model.
@@ -49,12 +80,11 @@ class Model(eqx.Module):
49
80
  filter_spec = eqx.tree_at(
50
81
  lambda model: getattr(model, f.name),
51
82
  filter_spec,
52
- replace=attr.filter_spec,
83
+ replace=attr.filter_spec
53
84
  )
54
85
 
55
86
  return filter_spec
56
87
 
57
- @eqx.filter_jit
58
88
  def constrain_params(self) -> Tuple[Self, Scalar]:
59
89
  """
60
90
  Constrain parameters to the appropriate domain.
@@ -70,14 +100,14 @@ class Model(eqx.Module):
70
100
  attr = getattr(self, f.name)
71
101
 
72
102
  # Check if constrained parameter
73
- if isinstance(attr, Parameter) and "constraint" in f.metadata:
103
+ if isinstance(attr, Parameter) and ("constraint" in f.metadata):
74
104
  param = attr
75
105
  constraint = f.metadata["constraint"]
76
106
 
77
107
  # Apply constraint
78
108
  param, laj = constraint.constrain(param)
79
109
 
80
- # Update parameters for constrained model
110
+ # Update parameters for constrained model at same node
81
111
  constrained = eqx.tree_at(
82
112
  lambda model: getattr(model, f.name), constrained, replace=param
83
113
  )
@@ -87,7 +117,6 @@ class Model(eqx.Module):
87
117
 
88
118
  return constrained, target
89
119
 
90
- @eqx.filter_jit
91
120
  def transform_params(self) -> Tuple[Self, Scalar]:
92
121
  """
93
122
  Apply a custom transformation to parameters if needed(defaults to constrained parameters).
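Putting `define`, constraints, and `constrain_params` together, here is a hedged sketch of a model definition in the style of the README example; `ScaleModel`, its fields, and the `data["y"]` key are hypothetical, the `Lower` constraint and `normal` density come from elsewhere in this diff, and whether instantiation and fitting work end-to-end in this release is not verified here:

```py
from typing import Dict

from jaxtyping import Array, Scalar

from bayinx import Model, Parameter, define
from bayinx.constraints import Lower
from bayinx.dists import normal

class ScaleModel(Model):
    mu: Parameter[Array] = define(shape=(1,))
    sigma: Parameter[Array] = define(shape=(1,), constraint=Lower(0.0))

    def eval(self, data: Dict[str, Array]) -> Scalar:
        # Constrain parameters; `target` starts at the log-Jacobian adjustment
        self, target = self.constrain_params()

        # Likelihood: y ~ Normal(mu, sigma)
        target += normal.logprob(data["y"], self.mu(), self.sigma()).sum()

        return target
```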
@@ -10,6 +10,8 @@ from optax import GradientTransformation, OptState, Schedule
10
10
  from ._model import Model
11
11
 
12
12
  M = TypeVar("M", bound=Model)
13
+
14
+
13
15
  @eqx.filter_jit
14
16
  def optimize_model(
15
17
  model: M,
@@ -39,6 +41,7 @@ def optimize_model(
39
41
 
40
42
  # Evaluate posterior
41
43
  return model.eval(data)
44
+
42
45
  eval_grad: Callable[[M], M] = eqx.filter_jit(eqx.filter_grad(eval))
43
46
 
44
47
  # Construct scheduler
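The `eval_grad = eqx.filter_jit(eqx.filter_grad(eval))` pattern above differentiates only with respect to the (inexact) array leaves of the model. A minimal standalone Equinox sketch of the same idea, using a made-up toy module rather than bayinx's `Model`:

```py
import equinox as eqx
import jax.numpy as jnp

class Toy(eqx.Module):
    w: jnp.ndarray
    name: str  # non-array field: ignored by filter_grad

def loss(m: Toy) -> jnp.ndarray:
    return jnp.sum((m.w - 3.0) ** 2)

grad_fn = eqx.filter_jit(eqx.filter_grad(loss))
grads = grad_fn(Toy(w=jnp.zeros(3), name="toy"))
print(grads.w)  # gradient with respect to the array leaf only
```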
@@ -5,6 +5,8 @@ import jax.tree as jt
5
5
  from jaxtyping import PyTree
6
6
 
7
7
  T = TypeVar("T", bound=PyTree)
8
+
9
+
8
10
  class Parameter(eqx.Module, Generic[T]):
9
11
  """
10
12
  A container for a parameter of a `Model`.
@@ -26,19 +28,18 @@ class Parameter(eqx.Module, Generic[T]):
26
28
 
27
29
  # Default filter specification
28
30
  @property
29
- @eqx.filter_jit
30
31
  def filter_spec(self) -> Self:
31
32
  """
32
- Generates a filter specification to filter out static parameters.
33
+ Generates a filter specification to filter for dynamic parameters.
33
34
  """
34
35
  # Generate empty specification
35
- filter_spec = jt.map(lambda _: False, self)
36
+ filter_spec: Self = jt.map(lambda _: False, self)
36
37
 
37
- # Specify Array leaves
38
+ # Specify Array-like leaves
38
39
  filter_spec = eqx.tree_at(
39
- lambda params: params.vals,
40
+ lambda param: param.vals,
40
41
  filter_spec,
41
- replace=jt.map(eqx.is_array_like, self.vals),
42
+ replace=jt.map(eqx.is_inexact_array_like, self.vals),
42
43
  )
43
44
 
44
45
  return filter_spec
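The `filter_spec` built above is a boolean pytree with the same structure as the parameter, later consumed by `eqx.partition`/`eqx.combine`. A small sketch of that generic Equinox pattern (not bayinx-specific):

```py
import equinox as eqx
import jax.numpy as jnp
import jax.tree as jt

params = {"vals": jnp.array([1.0, 2.0]), "flag": True}

# Mark only inexact (floating-point) array leaves as dynamic
filter_spec = jt.map(eqx.is_inexact_array_like, params)

dyn, static = eqx.partition(params, filter_spec)  # split into dynamic / static parts
restored = eqx.combine(dyn, static)               # reassemble the original pytree
print(dyn, static, restored)
```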
@@ -22,11 +22,11 @@ class Variational(eqx.Module, Generic[M]):
22
22
 
23
23
  # Attributes
24
24
  - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
25
- - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
25
+ - `_static`: The static component of a partitioned `Model` used to initialize the `Variational` object.
26
26
  """
27
27
 
28
28
  _unflatten: Callable[[Array], M]
29
- _constraints: M
29
+ _static: M
30
30
 
31
31
  @abstractmethod
32
32
  def filter_spec(self):
@@ -69,7 +69,7 @@ class Variational(eqx.Module, Generic[M]):
69
69
  model: M = self._unflatten(draw)
70
70
 
71
71
  # Combine with constraints
72
- model: M = eqx.combine(model, self._constraints)
72
+ model: M = eqx.combine(model, self._static)
73
73
 
74
74
  return model
75
75
 
@@ -89,7 +89,6 @@ class Variational(eqx.Module, Generic[M]):
89
89
  # Evaluate posterior density
90
90
  return model.eval(data)
91
91
 
92
- # TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
93
92
  @eqx.filter_jit
94
93
  def fit(
95
94
  self,
@@ -116,11 +115,9 @@ class Variational(eqx.Module, Generic[M]):
116
115
  dyn, static = eqx.partition(self, self.filter_spec)
117
116
 
118
117
  # Construct scheduler
119
- schedule: Schedule = opx.warmup_cosine_decay_schedule(
120
- init_value=1e-16,
121
- peak_value=learning_rate,
122
- warmup_steps=int(max_iters / 10),
123
- decay_steps=max_iters - int(max_iters / 10),
118
+ schedule: Schedule = opx.cosine_decay_schedule(
119
+ init_value=learning_rate,
120
+ decay_steps=max_iters,
124
121
  )
125
122
 
126
123
  # Initialize optimizer
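The change above swaps a warmup-plus-cosine-decay learning-rate schedule for a plain cosine decay that starts at `learning_rate`. A short sketch comparing the two Optax schedules, with values chosen arbitrarily for illustration:

```py
import optax

max_iters, learning_rate = 1000, 1e-2

old = optax.warmup_cosine_decay_schedule(
    init_value=1e-16,
    peak_value=learning_rate,
    warmup_steps=max_iters // 10,
    decay_steps=max_iters - max_iters // 10,
)
new = optax.cosine_decay_schedule(init_value=learning_rate, decay_steps=max_iters)

print(old(0), new(0))      # old starts near zero, new starts at learning_rate
print(old(500), new(500))  # both decay toward zero over the run
```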
@@ -175,7 +172,7 @@ class Variational(eqx.Module, Generic[M]):
175
172
  return eqx.combine(dyn, static)
176
173
 
177
174
  @eqx.filter_jit
178
- def posterior_predictive(
175
+ def _posterior_predictive(
179
176
  self,
180
177
  func: Callable[[M, Any], Array],
181
178
  n: int,
@@ -111,6 +111,11 @@ def sample(
111
111
 
112
112
  # Construct draws
113
113
  draws = jr.uniform(key, shape)
114
- draws = mu + sigma * ndtri(normal.cdf(-mu/sigma, 0.0, 1.0) + draws * normal.cdf(mu/sigma, 0.0, 1.0))
114
+ draws = mu + sigma * ndtri(
115
+ normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
116
+ )
117
+
118
+ # Censor draws
119
+ draws.at[censor <= draws].set(censor)
115
120
 
116
121
  return draws
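For the censoring step above: right-censoring records any draw at or beyond the censoring point as the censoring point itself. A standalone illustration of that operation with plain `jnp` (independent of the function in this diff):

```py
import jax.numpy as jnp

draws = jnp.array([0.3, 1.2, 2.7, 5.0])
censor = 2.0

# Cap draws at the censoring point
censored_draws = jnp.minimum(draws, censor)
print(censored_draws)  # [0.3, 1.2, 2.0, 2.0]
```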
@@ -0,0 +1,113 @@
1
+ import jax.numpy as jnp
2
+ from jax.scipy.special import gammaln
3
+ from jaxtyping import Array, ArrayLike, Float, UInt
4
+
5
+ __PI = 3.141592653589793
6
+
7
+
8
+ def __binom(x, y):
9
+ """
10
+ Helper function for the Binomial coefficient.
11
+ """
12
+ return jnp.exp(gammaln(x + 1) - gammaln(y + 1) - gammaln(x - y + 1))
13
+
14
+
15
+ def prob(
16
+ x: UInt[ArrayLike, "..."],
17
+ mu: Float[ArrayLike, "..."],
18
+ phi: Float[ArrayLike, "..."],
19
+ ) -> Float[Array, "..."]:
20
+ """
21
+ The probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.
22
+
23
+ # Parameters
24
+ - `x`: Where to evaluate the PMF.
25
+ - `mu`: The mean.
26
+ - `phi`: The inverse overdispersion.
27
+
28
+ # Returns
29
+ The PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
30
+ """
31
+ # Cast to Array
32
+ x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)
33
+
34
+ return jnp.exp(logprob(x, mu, phi))
35
+
36
+
37
+ def logprob(
38
+ x: UInt[ArrayLike, "..."],
39
+ mu: Float[ArrayLike, "..."],
40
+ phi: Float[ArrayLike, "..."],
41
+ ) -> Float[Array, "..."]:
42
+ """
43
+ The log-transformed probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.
44
+
45
+ # Parameters
46
+ - `x`: Where to evaluate the log PMF.
47
+ - `mu`: The mean.
48
+ - `phi`: The inverse overdispersion.
49
+
50
+ # Returns
51
+ The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
52
+ """
53
+ # Cast to Array
54
+ x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)
55
+
56
+ # Evaluate log PMF
57
+ evals: Array = jnp.where(
58
+ x < 0,
59
+ -jnp.inf,
60
+ (
61
+ gammaln(x + phi)
62
+ - gammaln(x + 1)
63
+ - gammaln(phi)
64
+ + x * (jnp.log(mu) - jnp.log(mu + phi))
65
+ + phi * (jnp.log(phi) - jnp.log(mu + phi))
66
+ ),
67
+ )
68
+
69
+ return evals
70
+
71
+
72
+ def cdf(
73
+ x: Float[ArrayLike, "..."],
74
+ mu: Float[ArrayLike, "..."],
75
+ sigma: Float[ArrayLike, "..."],
76
+ ) -> Float[Array, "..."]:
77
+ # Cast to Array
78
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
79
+
80
+ return jnp.array(1.0)
81
+
82
+
83
+ def logcdf(
84
+ x: Float[ArrayLike, "..."],
85
+ mu: Float[ArrayLike, "..."],
86
+ sigma: Float[ArrayLike, "..."],
87
+ ) -> Float[Array, "..."]:
88
+ # Cast to Array
89
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
90
+
91
+ return jnp.array(1.0)
92
+
93
+
94
+ def ccdf(
95
+ x: Float[ArrayLike, "..."],
96
+ mu: Float[ArrayLike, "..."],
97
+ sigma: Float[ArrayLike, "..."],
98
+ ) -> Float[Array, "..."]:
99
+ # Cast to Array
100
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
101
+
102
+ return jnp.array(1.0)
103
+
104
+
105
+ def logccdf(
106
+ x: Float[ArrayLike, "..."],
107
+ mu: Float[ArrayLike, "..."],
108
+ sigma: Float[ArrayLike, "..."],
109
+ ) -> Float[Array, "..."]:
110
+ # Cast to Array
111
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
112
+
113
+ return jnp.array(1.0)
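The mean/inverse-overdispersion parameterization above corresponds to a standard Negative Binomial with r = phi successes and success probability p = phi / (mu + phi), so the log PMF can be cross-checked against `scipy.stats.nbinom`. A hedged sketch, assuming `bayinx.dists.negbinom3` is importable as laid out above:

```py
import jax.numpy as jnp
from scipy.stats import nbinom

from bayinx.dists import negbinom3

x, mu, phi = 4, 3.0, 2.0

ours = negbinom3.logprob(x, mu, phi)
ref = nbinom.logpmf(x, n=phi, p=phi / (mu + phi))
print(jnp.asarray(ours), ref)  # should agree up to floating-point error
```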
@@ -1,5 +1,7 @@
1
1
  import jax.numpy as jnp
2
- from jaxtyping import Array, ArrayLike, Float
2
+ import jax.random as jr
3
+ from jax.scipy.special import ndtri
4
+ from jaxtyping import Array, ArrayLike, Float, Key
3
5
 
4
6
  from bayinx.dists import normal
5
7
 
@@ -251,3 +253,38 @@ def logccdf(
251
253
  evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
252
254
 
253
255
  return evals
256
+
257
+
258
+ def sample(
259
+ n: int,
260
+ mu: Float[ArrayLike, "..."],
261
+ sigma: Float[ArrayLike, "..."],
262
+ key: Key = jr.PRNGKey(0),
263
+ ) -> Float[Array, "..."]:
264
+ """
265
+ Sample from a positive Normal distribution.
266
+
267
+ # Parameters
268
+ - `n`: Number of draws to sample per-parameter.
269
+ - `mu`: The mean.
270
+ - `sigma`: The standard deviation.
271
+
272
+ # Returns
273
+ Draws from a positive Normal distribution. The output will have the shape of (n,) + the broadcasted shapes of `mu` and `sigma`.
274
+ """
275
+ # Cast to Array
276
+ mu, sigma = (
277
+ jnp.asarray(mu),
278
+ jnp.asarray(sigma),
279
+ )
280
+
281
+ # Derive shape
282
+ shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape)
283
+
284
+ # Construct draws
285
+ draws = jr.uniform(key, shape)
286
+ draws = mu + sigma * ndtri(
287
+ normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
288
+ )
289
+
290
+ return draws
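The draw construction above is inverse-CDF sampling for a Normal(mu, sigma) truncated to the positive reals: with U ~ Uniform(0, 1), X = mu + sigma * Phi^{-1}(Phi(-mu/sigma) + U * Phi(mu/sigma)) is a positive-Normal draw. A standalone sketch of the same recipe in plain JAX (not calling bayinx):

```py
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import ndtri
from jax.scipy.stats import norm

mu, sigma, n = 1.5, 2.0, 100_000
u = jr.uniform(jr.PRNGKey(0), (n,))

# Inverse-CDF sampling for Normal(mu, sigma) truncated to (0, inf)
draws = mu + sigma * ndtri(norm.cdf(-mu / sigma) + u * norm.cdf(mu / sigma))

print(bool(draws.min() > 0))  # all draws are positive
print(draws.mean())           # close to the truncated-Normal mean
```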