bayinx 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayinx might be problematic. Click here for more details.
bayinx/constraints.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from typing import Tuple
|
|
2
2
|
|
|
3
3
|
import equinox as eqx
|
|
4
|
+
import jax.nn as jnn
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import jax.tree as jt
|
|
6
|
-
from jaxtyping import Scalar, ScalarLike
|
|
7
|
+
from jaxtyping import Array, PyTree, Scalar, ScalarLike
|
|
7
8
|
|
|
8
9
|
from bayinx.core import Constraint, Parameter
|
|
9
10
|
from bayinx.core._parameter import T
|
|
@@ -17,7 +18,8 @@ class Lower(Constraint):
|
|
|
17
18
|
lb: Scalar
|
|
18
19
|
|
|
19
20
|
def __init__(self, lb: ScalarLike):
    """
    Store the lower bound as a JAX array.

    # Parameters
    - `lb`: The lower bound. `constrain` maps an unconstrained value `v` to
      `exp(v) + lb`, so constrained values are strictly greater than `lb`.
    """
    lower_bound = jnp.asarray(lb)
    self.lb = lower_bound
|
|
21
23
|
|
|
22
24
|
def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
    """
    Enforces a lower bound constraint on the parameter and adjusts the posterior density.

    Each selected leaf `v` is mapped to `exp(v) + lb`.

    # Parameters
    - `param`: The unconstrained `Parameter`.

    # Returns
    A tuple containing:
        - A modified `Parameter` with relevant leaves satisfying the constraint.
        - A scalar Array representing the log-absolute-Jacobian of the transformation.
    """
    # Extract relevant parameters (all inexact Arrays)
    dyn, static = eqx.partition(param, param.filter_spec)

    # Compute Jacobian adjustment: for x = exp(v) + lb, log|dx/dv| = v, so the
    # total adjustment is the sum of every unconstrained entry. The explicit
    # initializer keeps the reduction well-defined when `dyn` has no array
    # leaves (a plain reduce over an empty sequence raises TypeError).
    total_laj: Scalar = jt.reduce(
        lambda a, b: a + b, jt.map(jnp.sum, dyn), jnp.array(0.0)
    )

    # Compute transformation
    dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)

    # Combine into full parameter object
    param = eqx.combine(dyn, static)

    return param, total_laj
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class LogSimplex(Constraint):
    """
    Enforces a log-transformed simplex constraint on the parameter.

    Each constrained leaf holds the logarithms of weights produced by a
    stick-breaking transform, scaled so that `exp(leaf).sum() == sum`.

    # Attributes
    - `sum`: The total sum of the parameter.
    """

    sum: Scalar

    def __init__(self, sum_val: ScalarLike = 1.0):
        """
        # Parameters
        - `sum_val`: The target sum for the exponentiated simplex. Defaults to 1.0.
          Its logarithm is taken in `_transform_leaf`, so it presumably must be
          positive; this is not validated here.
        """
        self.sum = jnp.asarray(sum_val)

    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
        """
        Enforces a log-transformed simplex constraint on the parameter and adjusts the posterior density.

        # Parameters
        - `param`: The unconstrained `Parameter`.

        # Returns
        A tuple containing:
            - A modified `Parameter` with relevant leaves satisfying the constraint.
            - A scalar Array representing the log-absolute-Jacobian of the transformation.
        """
        # Partition the parameter into dynamic (to be transformed) and static parts
        dyn, static = eqx.partition(param, param.filter_spec)

        # Transform each array leaf independently. Flattening explicitly keeps
        # each (values, laj) pair intact; mapping a tuple-returning function and
        # then mapping again to pick out tuple fields would descend into the
        # returned tuples and index the arrays themselves.
        leaves, treedef = jt.flatten(dyn)
        results = [self._transform_leaf(leaf) for leaf in leaves]

        # Rebuild the dynamic tree from the constrained leaves
        dyn_constrained = jt.unflatten(treedef, [values for values, _ in results])

        # Sum per-leaf Jacobian adjustments; starting from 0 handles an empty tree
        total_laj: Scalar = jnp.array(0.0)
        for _, leaf_laj in results:
            total_laj = total_laj + leaf_laj

        # Recombine the transformed dynamic parts with the static parts
        param = eqx.combine(dyn_constrained, static)

        return param, total_laj

    def _transform_leaf(self, x: Array) -> Tuple[Array, Scalar]:
        """
        Internal function that applies a log-transformed simplex constraint on a single array.

        The array is flattened. With K = `x.size`, its first K - 1 entries are the
        free coordinates. NOTE(review): the last entry of `x` does not influence
        the output, so it gets no density contribution from this transform;
        confirm that is intended.

        # Returns
        `(log_w, laj)`, where `log_w` has the shape of `x` and
        `logsumexp(log_w) == log(sum)`. `laj` is the log-absolute-Jacobian of the
        map from the K - 1 free coordinates to the first K - 1 entries of `log_w`.
        """
        output_shape = x.shape
        zero = jnp.array(0.0)

        # Degenerate cases: nothing to constrain, or a single weight equal to `sum`
        if x.size == 0:
            return x, zero
        if x.size == 1:
            return jnp.full(output_shape, jnp.log(self.sum)), zero

        # Free coordinates: first K - 1 entries of the flattened array
        y = x.flatten()[:-1]

        # Shifted cumulative sum: zeta_k = sum_{j<k} y_j (zeta_1 = 0)
        zeta = jnp.concatenate([jnp.zeros(1, dtype=y.dtype), jnp.cumsum(y)[:-1]])
        z = y - zeta

        # Stick-breaking proportions eta_k = sigmoid(z_k), kept on the log scale;
        # log_sigmoid avoids log(0) when the sigmoid saturates.
        log_eta = jnn.log_sigmoid(z)
        log1m_eta = jnn.log_sigmoid(-z)  # log(1 - eta_k)

        # Log of the stick remaining before break k: sum_{j<k} log(1 - eta_j); length K
        log_remaining = jnp.concatenate(
            [jnp.zeros(1, dtype=y.dtype), jnp.cumsum(log1m_eta)]
        )

        # log w_k = log eta_k + log_remaining_k for k < K; log w_K = log_remaining_K
        log_w = jnp.concatenate([log_eta + log_remaining[:-1], log_remaining[-1:]])

        # Scale unit simplex on log-scale (an additive shift: no Jacobian contribution)
        log_w = log_w + jnp.log(self.sum)

        # Jacobian adjustment for y -> log w_{1:K-1}:
        #   y -> eta is lower-triangular with diagonal eta_k * (1 - eta_k);
        #   eta -> log w_{1:K-1} is lower-triangular with diagonal 1 / eta_k.
        # The determinant is the product of the diagonals: log|J| = sum_k log(1 - eta_k).
        laj = jnp.sum(log1m_eta)

        return log_w.reshape(output_shape), laj
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: bayinx
|
|
3
|
-
Version: 0.4.0
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: Bayesian Inference with JAX
|
|
5
5
|
Author: Todd McCready
|
|
6
6
|
Maintainer: Todd McCready
|
|
@@ -12,13 +12,13 @@ Requires-Dist: jaxtyping>=0.2.36
|
|
|
12
12
|
Requires-Dist: optax>=0.2.4
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
|
|
15
|
-
#
|
|
15
|
+
# Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
|
16
16
|
|
|
17
|
-
The original aim of this project was to build a PPL in Python that is similar in feel to
|
|
17
|
+
The original aim of this project was to build a PPL in Python that is similar in feel to Stan or Nimble (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
|
|
18
18
|
|
|
19
|
-
Part of the reason for this move is that Rust's ability to embed a "nice" DSL is comparitively easier due to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to
|
|
19
|
+
Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to Stan and parse it into valid Rust code. Additionally, bayinx is already relatively functional (give or take a few things to clean up and document), and it offers enough for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filtering parameters for gradient calculations, etc.
|
|
20
20
|
|
|
21
|
-
Instead, this project is narrowing on implementing much of
|
|
21
|
+
Instead, this project is narrowing its focus to implementing much of Stan's functionality (restricted to continuously parameterized models; point estimation, VI, and MCMC; etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, users will work with `target` directly and return the density as shown below:
|
|
22
22
|
|
|
23
23
|
```py
|
|
24
24
|
class NormalDist(Model):
|
|
@@ -34,7 +34,7 @@ class NormalDist(Model):
|
|
|
34
34
|
return target
|
|
35
35
|
```
|
|
36
36
|
|
|
37
|
-
I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try and replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for
|
|
37
|
+
I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try to replicate what `baycian` is trying to do, but that is for the future; versions `0.4.#` will retain the functionality needed for disize.
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
# TODO
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
bayinx/__init__.py,sha256=8etrxEtEGEzSDmKsW0TB4XoUGLiMPt9wpwNR8CGe1gU,93
|
|
2
|
-
bayinx/constraints.py,sha256=
|
|
2
|
+
bayinx/constraints.py,sha256=2ufHsXR-_bWKR4WKKuR-OTjj3XCc4TkSeHVGWYadwCg,4387
|
|
3
3
|
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
bayinx/core/__init__.py,sha256=samkrHp2zYyj8n37k-06tlaVrSqbtcgoa1LO0btAEHc,338
|
|
5
5
|
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
|
@@ -32,7 +32,7 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=7KXuukwzVtMRIa8bSK_4pjnnP-lLIzVJBCAuKVy
|
|
|
32
32
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
|
33
33
|
bayinx/mhx/vi/flows/radial.py,sha256=AyaqLJCwn871L6E8lBCU4Y8zZBF9UYZu6KIhzV6Z6wo,2503
|
|
34
34
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
|
35
|
-
bayinx-0.4.
|
|
36
|
-
bayinx-0.4.
|
|
37
|
-
bayinx-0.4.
|
|
38
|
-
bayinx-0.4.
|
|
35
|
+
bayinx-0.4.1.dist-info/METADATA,sha256=7Zw-9hVqUVxj3ncyGBfn72FQzTDomdDaXXH2hOsJM60,2989
|
|
36
|
+
bayinx-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
37
|
+
bayinx-0.4.1.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
|
38
|
+
bayinx-0.4.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|