bayinx 0.4.0__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (44)
  1. {bayinx-0.4.0 → bayinx-0.4.1}/PKG-INFO +6 -6
  2. {bayinx-0.4.0 → bayinx-0.4.1}/README.md +5 -5
  3. {bayinx-0.4.0 → bayinx-0.4.1}/pyproject.toml +2 -2
  4. bayinx-0.4.1/src/bayinx/constraints.py +135 -0
  5. bayinx-0.4.0/src/bayinx/constraints.py +0 -46
  6. {bayinx-0.4.0 → bayinx-0.4.1}/.github/workflows/release_and_publish.yml +0 -0
  7. {bayinx-0.4.0 → bayinx-0.4.1}/.gitignore +0 -0
  8. {bayinx-0.4.0 → bayinx-0.4.1}/LICENSE +0 -0
  9. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/__init__.py +0 -0
  10. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/__init__.py +0 -0
  11. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/_constraint.py +0 -0
  12. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/_flow.py +0 -0
  13. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/_model.py +0 -0
  14. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/_optimization.py +0 -0
  15. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/_parameter.py +0 -0
  16. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/core/_variational.py +0 -0
  17. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/__init__.py +0 -0
  18. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/bernoulli.py +0 -0
  19. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/censored/__init__.py +0 -0
  20. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/censored/gamma2/__init__.py +0 -0
  21. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/censored/gamma2/r.py +0 -0
  22. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/censored/negbinom3/r.py +0 -0
  23. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/censored/posnormal/__init__.py +0 -0
  24. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/censored/posnormal/r.py +0 -0
  25. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/gamma2.py +0 -0
  26. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/negbinom3.py +0 -0
  27. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/normal.py +0 -0
  28. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/posnormal.py +0 -0
  29. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/dists/uniform.py +0 -0
  30. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/__init__.py +0 -0
  31. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/opt/__init__.py +0 -0
  32. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/__init__.py +0 -0
  33. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  34. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  35. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  36. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  37. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  38. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/meanfield.py +0 -0
  39. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
  40. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/mhx/vi/standard.py +0 -0
  41. {bayinx-0.4.0 → bayinx-0.4.1}/src/bayinx/py.typed +0 -0
  42. {bayinx-0.4.0 → bayinx-0.4.1}/tests/__init__.py +0 -0
  43. {bayinx-0.4.0 → bayinx-0.4.1}/tests/test_variational.py +0 -0
  44. {bayinx-0.4.0 → bayinx-0.4.1}/uv.lock +0 -0
{bayinx-0.4.0 → bayinx-0.4.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayinx
-Version: 0.4.0
+Version: 0.4.1
 Summary: Bayesian Inference with JAX
 Author: Todd McCready
 Maintainer: Todd McCready
@@ -12,13 +12,13 @@ Requires-Dist: jaxtyping>=0.2.36
 Requires-Dist: optax>=0.2.4
 Description-Content-Type: text/markdown
 
-# `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+# Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
 
-The original aim of this project was to build a PPL in Python that is similar in feel to `Stan` or `Nimble`(where there is a nice declarative syntax for defining the model) and allows for arbitrary models(e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
+The original aim of this project was to build a PPL in Python that is similar in feel to Stan or Nimble(where there is a nice declarative syntax for defining the model) and allows for arbitrary models(e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
 
-Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparitively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to `Stan` and parse it to valid Rust code. Additionally, the current state of `bayinx` is relatively functional(plus/minus a few things to clean-up and documentation) and it offers enough functionality for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite `disize` in Python with JAX, and `bayinx` makes it easy to handle constraining transformations, filtering for parameters for gradient calculations, etc.
+Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparitively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to Stan and parse it to valid Rust code. Additionally, the current state of bayinx is relatively functional(plus/minus a few things to clean-up and documentation) and it offers enough for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filtering for parameters for gradient calculations, etc.
 
-Instead, this project is narrowing on implementing much of `Stan`'s functionality(restricted to continuously parameterized models, point estimation + vi + mcmc, etc) without most of the nice syntax, at least for versions `0.4.#`. Therefore, people will work with `target` directly and return the density like below:
+Instead, this project is narrowing on implementing much of Stan's functionality(restricted to continuously parameterized models, point estimation + vi + mcmc, etc) without most of the nice syntax, at least for versions `0.4.#`. Therefore, people will work with `target` directly and return the density like below:
 
 ```py
 class NormalDist(Model):
@@ -34,7 +34,7 @@ class NormalDist(Model):
         return target
 ```
 
-I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try and replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for `disize`.
+I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try and replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for disize.
 
 
 # TODO
{bayinx-0.4.0 → bayinx-0.4.1}/README.md
@@ -1,10 +1,10 @@
-# `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
+# Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
 
-The original aim of this project was to build a PPL in Python that is similar in feel to `Stan` or `Nimble`(where there is a nice declarative syntax for defining the model) and allows for arbitrary models(e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
+The original aim of this project was to build a PPL in Python that is similar in feel to Stan or Nimble(where there is a nice declarative syntax for defining the model) and allows for arbitrary models(e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
 
-Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparitively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to `Stan` and parse it to valid Rust code. Additionally, the current state of `bayinx` is relatively functional(plus/minus a few things to clean-up and documentation) and it offers enough functionality for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite `disize` in Python with JAX, and `bayinx` makes it easy to handle constraining transformations, filtering for parameters for gradient calculations, etc.
+Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparitively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to Stan and parse it to valid Rust code. Additionally, the current state of bayinx is relatively functional(plus/minus a few things to clean-up and documentation) and it offers enough for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filtering for parameters for gradient calculations, etc.
 
-Instead, this project is narrowing on implementing much of `Stan`'s functionality(restricted to continuously parameterized models, point estimation + vi + mcmc, etc) without most of the nice syntax, at least for versions `0.4.#`. Therefore, people will work with `target` directly and return the density like below:
+Instead, this project is narrowing on implementing much of Stan's functionality(restricted to continuously parameterized models, point estimation + vi + mcmc, etc) without most of the nice syntax, at least for versions `0.4.#`. Therefore, people will work with `target` directly and return the density like below:
 
 ```py
 class NormalDist(Model):
@@ -20,7 +20,7 @@ class NormalDist(Model):
         return target
 ```
 
-I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try and replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for `disize`.
+I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try and replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for disize.
 
 
 # TODO
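The `NormalDist` snippet above is cut off by the diff context, so here is a rough, self-contained sketch of the "work with `target` directly" pattern the README describes. It uses a hand-rolled normal log-density in plain JAX; this diff does not show bayinx's `Model`/`Parameter` API in full, so none of the names below should be read as the package's actual interface.

```py
# Illustrative sketch only: plain JAX, not bayinx's actual Model/Parameter API.
import jax.numpy as jnp
from jaxtyping import Array, Scalar


def normal_logpdf(x: Array, mu: Scalar, sigma: Scalar) -> Array:
    return -0.5 * jnp.log(2.0 * jnp.pi) - jnp.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def target_density(data: Array, mu: Scalar, log_sigma: Scalar) -> Scalar:
    target = jnp.asarray(0.0)

    # Constrain sigma > 0 via exp; the log-abs-Jacobian of exp is the
    # unconstrained value itself, so add log_sigma to the density.
    sigma = jnp.exp(log_sigma)
    target += log_sigma

    # Accumulate the likelihood, Stan-style.
    target += jnp.sum(normal_logpdf(data, mu, sigma))

    return target
```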
{bayinx-0.4.0 → bayinx-0.4.1}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "bayinx"
-version = "0.4.0"
+version = "0.4.1"
 description = "Bayesian Inference with JAX"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -21,7 +21,7 @@ build-backend = "hatchling.build"
 addopts = "-q --benchmark-min-rounds=30 --benchmark-columns=rounds,mean,median,stddev --benchmark-group-by=func"
 
 [tool.bumpversion]
-current_version = "0.4.0"
+current_version = "0.4.1"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
 serialize = ["{major}.{minor}.{patch}"]
 search = "{current_version}"
bayinx-0.4.1/src/bayinx/constraints.py (new file)
@@ -0,0 +1,135 @@
+from typing import Tuple
+
+import equinox as eqx
+import jax.nn as jnn
+import jax.numpy as jnp
+import jax.tree as jt
+from jaxtyping import Array, PyTree, Scalar, ScalarLike
+
+from bayinx.core import Constraint, Parameter
+from bayinx.core._parameter import T
+
+
+class Lower(Constraint):
+    """
+    Enforces a lower bound on the parameter.
+    """
+
+    lb: Scalar
+
+    def __init__(self, lb: ScalarLike):
+        # assert greater than 1
+        self.lb = jnp.asarray(lb)
+
+    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
+        """
+        Enforces a lower bound on the parameter and adjusts the posterior density.
+
+        # Parameters
+        - `param`: The unconstrained `Parameter`.
+
+        # Returns
+        A tuple containing:
+        - A modified `Parameter` with relevant leaves satisfying the constraint.
+        - A scalar Array representing the log-absolute-Jacobian of the transformation.
+        """
+        # Extract relevant parameters(all inexact Arrays)
+        dyn, static = eqx.partition(param, param.filter_spec)
+
+        # Compute Jacobian adjustment
+        total_laj: Scalar = jt.reduce(lambda a, b: a + b, jt.map(jnp.sum, dyn))
+
+        # Compute transformation
+        dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)
+
+        # Combine into full parameter object
+        param = eqx.combine(dyn, static)
+
+        return param, total_laj
+
+
+class LogSimplex(Constraint):
+    """
+    Enforces a log-transformed simplex constraint on the parameter.
+
+    # Attributes
+    - `sum`: The total sum of the parameter.
+    """
+
+    sum: Scalar
+
+    def __init__(self, sum_val: ScalarLike = 1.0):
+        """
+        # Parameters
+        - `sum_val`: The target sum for the exponentiated simplex. Defaults to 1.0.
+        """
+        self.sum = jnp.asarray(sum_val)
+
+    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
+        """
+        Enforces a log-transformed simplex constraint on the parameter and adjusts the posterior density.
+
+        # Parameters
+        - `param`: The unconstrained `Parameter`.
+
+        # Returns
+        A tuple containing:
+        - A modified `Parameter` with relevant leaves satisfying the constraint.
+        - A scalar Array representing the log-absolute-Jacobian of the transformation.
+        """
+        # Partition the parameter into dynamic (to be transformed) and static parts
+        dyn, static = eqx.partition(param, param.filter_spec)
+
+        # Map transformation leaf-wise
+        transformed = jt.map(self._transform_leaf, dyn) ## filter spec handles subsetting arrays, is_leaf unnecessary
+
+        # Extract constrained parameters and Jacobian adjustments
+        dyn_constrained: PyTree = jt.map(lambda x: x[0], transformed)
+        lajs: PyTree = jt.map(lambda x: x[1], transformed)
+
+        # Sum to get total Jacobian adjustment
+        total_laj = jt.reduce(lambda a, b: a + b, lajs)
+
+        # Recombine the transformed dynamic parts with the static parts
+        param = eqx.combine(dyn_constrained, static)
+
+        return param, total_laj
+
+    def _transform_leaf(self, x: Array) -> Tuple[Array, Scalar]:
+        """
+        Internal function that applies a log-transformed simplex constraint on a single array.
+        """
+        laj: Scalar = jnp.array(0.0)
+
+        # Save output shape
+        output_shape: tuple[int, ...] = x.shape
+
+        if x.size == 1:
+            return(jnp.full(output_shape, jnp.log(self.sum)), laj)
+        else:
+            # Flatten x
+            x = x.flatten()
+
+            # Subset first K - 1 elements
+            x = x[:-1]
+
+            # Compute shifted cumulative sum
+            zeta: Array = jnp.concat([jnp.zeros(1), x.cumsum()[:-1]])
+
+            # Compute intermediate proportions vector
+            eta: Array = jnn.sigmoid(x - zeta)
+
+            # Compute Jacobian adjustment
+            laj += jnp.sum(jnp.log(eta) + jnp.log(1 - eta)) # TODO: check for correctness
+
+            # Compute log-transformed simplex weights
+            w: Array = jnp.log(eta) + jnp.concatenate([jnp.array([0.0]), jnp.log(jnp.cumprod((1-eta)[:-1]))])
+            w = jnp.concatenate([w, jnp.log(jnp.prod(1 - eta, keepdims=True))])
+
+            # Scale unit simplex on log-scale
+            w = w + jnp.log(self.sum)
+
+            # Reshape for output
+            w = w.reshape(output_shape)
+
+            return (w, laj)
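Both transforms are easy to sanity-check outside the `Parameter` machinery. Below is a minimal standalone sketch (1-D inputs assumed) that mirrors the math in the new constraints.py; the stick-breaking construction is similar in spirit to Stan's simplex transform, but this re-implementation is illustrative, not the package's API. It checks only the simplex property of the weights, not the Jacobian term that the source itself flags with a TODO.

```py
# Standalone re-implementation of the two transforms for sanity-checking.
# Mirrors the math in constraints.py; not bayinx's public API.
import jax.nn as jnn
import jax.numpy as jnp


def lower(x, lb):
    # y = exp(x) + lb, so log|dy/dx| = x and the total adjustment is x.sum().
    return jnp.exp(x) + lb, x.sum()


def log_simplex(x, total=1.0):
    x = x.flatten()[:-1]                              # first K - 1 free coordinates
    zeta = jnp.concatenate([jnp.zeros(1), x.cumsum()[:-1]])
    eta = jnn.sigmoid(x - zeta)                       # stick-breaking proportions
    head = jnp.log(eta) + jnp.concatenate(
        [jnp.zeros(1), jnp.log(jnp.cumprod((1 - eta)[:-1]))]
    )
    tail = jnp.log(jnp.prod(1 - eta, keepdims=True))  # leftover stick mass
    return jnp.concatenate([head, tail]) + jnp.log(total)


x = jnp.array([0.3, -1.2, 0.5])
y, laj = lower(x, lb=2.0)
assert (y > 2.0).all() and jnp.isclose(laj, x.sum())

w = log_simplex(jnp.array([0.1, -0.4, 0.7, 0.0]))
print(jnp.exp(w).sum())  # ~1.0: the exponentiated weights land on the simplex
```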
bayinx-0.4.0/src/bayinx/constraints.py (deleted)
@@ -1,46 +0,0 @@
-from typing import Tuple
-
-import equinox as eqx
-import jax.numpy as jnp
-import jax.tree as jt
-from jaxtyping import Scalar, ScalarLike
-
-from bayinx.core import Constraint, Parameter
-from bayinx.core._parameter import T
-
-
-class Lower(Constraint):
-    """
-    Enforces a lower bound on the parameter.
-    """
-
-    lb: Scalar
-
-    def __init__(self, lb: ScalarLike):
-        self.lb = jnp.array(lb)
-
-    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
-        """
-        Enforces a lower bound on the parameter and adjusts the posterior density.
-
-        # Parameters
-        - `param`: The unconstrained `Parameter`.
-
-        # Parameters
-        A tuple containing:
-        - A modified `Parameter` with relevant leaves satisfying the constraint.
-        - A scalar Array representing the log-absolute-Jacobian of the transformation.
-        """
-        # Extract relevant parameters(all Array)
-        dyn, static = eqx.partition(param, param.filter_spec)
-
-        # Compute density adjustment
-        laj: Scalar = jt.reduce(lambda a, b: a + b, jt.map(jnp.sum, dyn))
-
-        # Compute transformation
-        dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)
-
-        # Combine into full parameter object
-        param = eqx.combine(dyn, static)
-
-        return param, laj