bayinx 0.3.19__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayinx might be problematic. Click here for more details.
- bayinx/__init__.py +2 -2
- bayinx/constraints.py +135 -0
- bayinx/core/__init__.py +2 -2
- bayinx/core/_flow.py +0 -3
- bayinx/core/_model.py +41 -12
- bayinx/core/_optimization.py +3 -0
- bayinx/core/_parameter.py +7 -6
- bayinx/core/_variational.py +7 -10
- bayinx/dists/censored/negbinom3/r.py +0 -0
- bayinx/dists/censored/posnormal/r.py +6 -1
- bayinx/dists/negbinom3.py +113 -0
- bayinx/dists/posnormal.py +38 -1
- bayinx/dists/uniform.py +6 -2
- bayinx/mhx/vi/flows/fullaffine.py +0 -2
- bayinx/mhx/vi/flows/radial.py +1 -1
- bayinx/mhx/vi/meanfield.py +3 -6
- bayinx/mhx/vi/normalizing_flow.py +4 -5
- bayinx/mhx/vi/standard.py +2 -1
- bayinx-0.4.1.dist-info/METADATA +47 -0
- bayinx-0.4.1.dist-info/RECORD +38 -0
- bayinx/constraints/__init__.py +0 -3
- bayinx/constraints/lower.py +0 -50
- bayinx-0.3.19.dist-info/METADATA +0 -39
- bayinx-0.3.19.dist-info/RECORD +0 -37
- {bayinx-0.3.19.dist-info → bayinx-0.4.1.dist-info}/WHEEL +0 -0
- {bayinx-0.3.19.dist-info → bayinx-0.4.1.dist-info}/licenses/LICENSE +0 -0
bayinx/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
|
1
|
-
from bayinx.core import Model, Parameter,
|
|
1
|
+
from bayinx.core import Model, Parameter, define
|
|
2
2
|
|
|
3
|
-
__all__ = ["Model", "Parameter", "
|
|
3
|
+
__all__ = ["Model", "Parameter", "define"]
|
bayinx/constraints.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax.nn as jnn
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.tree as jt
|
|
7
|
+
from jaxtyping import Array, PyTree, Scalar, ScalarLike
|
|
8
|
+
|
|
9
|
+
from bayinx.core import Constraint, Parameter
|
|
10
|
+
from bayinx.core._parameter import T
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Lower(Constraint):
    """
    Enforces a lower bound on the parameter.

    Each relevant leaf `v` (as selected by the parameter's filter specification)
    is mapped to `exp(v) + lb`, which lies strictly above `lb`.

    # Attributes
    - `lb`: The lower bound.
    """

    lb: Scalar

    def __init__(self, lb: ScalarLike):
        """
        # Parameters
        - `lb`: The lower bound.
        """
        self.lb = jnp.asarray(lb)

    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
        """
        Enforces a lower bound on the parameter and adjusts the posterior density.

        # Parameters
        - `param`: The unconstrained `Parameter`.

        # Returns
        A tuple containing:
        - A modified `Parameter` with relevant leaves satisfying the constraint.
        - A scalar Array representing the log-absolute-Jacobian of the transformation.
        """
        # Extract relevant parameters (leaves selected by the filter specification)
        dyn, static = eqx.partition(param, param.filter_spec)

        # Compute Jacobian adjustment: d/dv [exp(v) + lb] = exp(v), so the
        # log-absolute-Jacobian is simply the sum of all unconstrained values.
        # The explicit initializer keeps this defined (zero) when no leaf is dynamic;
        # without it `reduce` raises on an empty tree.
        total_laj: Scalar = jt.reduce(
            lambda a, b: a + b, jt.map(jnp.sum, dyn), initializer=jnp.asarray(0.0)
        )

        # Compute transformation
        dyn = jt.map(lambda v: jnp.exp(v) + self.lb, dyn)

        # Combine into full parameter object
        param = eqx.combine(dyn, static)

        return param, total_laj
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class LogSimplex(Constraint):
    """
    Enforces a log-transformed simplex constraint on the parameter.

    Every relevant leaf (as selected by the parameter's filter specification) is
    mapped independently, via stick-breaking, to log-weights `w` of the same
    shape such that `exp(w)` is non-negative and sums to `sum`.

    # Attributes
    - `sum`: The total sum of the parameter.
    """

    sum: Scalar

    def __init__(self, sum_val: ScalarLike = 1.0):
        """
        # Parameters
        - `sum_val`: The target sum for the exponentiated simplex. Defaults to 1.0.
        """
        self.sum = jnp.asarray(sum_val)

    def constrain(self, param: Parameter[T]) -> Tuple[Parameter[T], Scalar]:
        """
        Enforces a log-transformed simplex constraint on the parameter and adjusts the posterior density.

        # Parameters
        - `param`: The unconstrained `Parameter`.

        # Returns
        A tuple containing:
        - A modified `Parameter` with relevant leaves satisfying the constraint.
        - A scalar Array representing the log-absolute-Jacobian of the transformation.
        """
        # Partition the parameter into dynamic (to be transformed) and static parts
        dyn, static = eqx.partition(param, param.filter_spec)

        # Transform each dynamic leaf independently. Leaves are flattened
        # explicitly because `_transform_leaf` returns a `(values, laj)` tuple.
        # Tuples are themselves PyTree nodes, so mapping again over a `jt.map`
        # result would descend into the tuples and index the arrays instead of
        # selecting the tuple components.
        leaves, treedef = jt.flatten(dyn)
        results = [self._transform_leaf(leaf) for leaf in leaves]

        # Rebuild the dynamic tree from the constrained leaves
        dyn_constrained: PyTree = jt.unflatten(treedef, [w for w, _ in results])

        # Sum to get total Jacobian adjustment (zero when no leaf is dynamic)
        total_laj: Scalar = jnp.asarray(0.0)
        for _, leaf_laj in results:
            total_laj = total_laj + leaf_laj

        # Recombine the transformed dynamic parts with the static parts
        param = eqx.combine(dyn_constrained, static)

        return param, total_laj

    def _transform_leaf(self, x: Array) -> Tuple[Array, Scalar]:
        """
        Internal function that applies a log-transformed simplex constraint on a single array.

        The array is flattened to length `K`. Its first `K - 1` entries drive a
        stick-breaking construction, and the result is reshaped back to `x.shape`.

        # Returns
        A tuple containing:
        - The log-weights, with the same shape as `x`.
        - The log-absolute-Jacobian of the map from the first `K - 1`
          unconstrained values to the first `K - 1` simplex weights `exp(w)`.
          The constant term from scaling by `sum` is omitted.
        """
        # Save output shape
        output_shape: tuple[int, ...] = x.shape
        log_sum = jnp.log(self.sum)

        # A single-element leaf carries the entire mass; nothing is free.
        if x.size == 1:
            return (jnp.full(output_shape, log_sum), jnp.array(0.0))

        # Only the first K - 1 unconstrained values are used.
        # NOTE(review): the last entry of `x` never affects the output, so it is a
        # flat direction of the unconstrained space — confirm this is intended.
        y = x.flatten()[:-1]

        # Shifted cumulative sum: zeta_k = y_0 + ... + y_{k-1}, with zeta_0 = 0
        zeta: Array = jnp.concatenate(
            [jnp.zeros(1, dtype=y.dtype), jnp.cumsum(y)[:-1]]
        )

        # Break proportions eta_k = sigmoid(y_k - zeta_k), kept on the log scale.
        # log_sigmoid(z) = log(eta) and log_sigmoid(-z) = log(1 - eta); both stay
        # finite where sigmoid itself would round to 0 or 1.
        z = y - zeta
        log_eta = jnn.log_sigmoid(z)
        log_one_minus_eta = jnn.log_sigmoid(-z)

        # log_stick[k] = log of the stick remaining before the k-th break.
        # It has length K; the final entry is what is left after all K - 1 breaks.
        log_stick = jnp.concatenate(
            [jnp.zeros(1, dtype=z.dtype), jnp.cumsum(log_one_minus_eta)]
        )

        # Compute Jacobian adjustment (stick-breaking, as in Stan's simplex
        # transform). The map y -> eta -> first K - 1 weights is lower
        # triangular; eta_k depends only on y_0..y_k. Its diagonal is
        # eta_k * (1 - eta_k) * stick_k.
        laj: Scalar = jnp.sum(log_eta + log_one_minus_eta + log_stick[:-1])

        # Compute log-transformed unit-simplex weights: K - 1 pieces + remainder
        w: Array = jnp.concatenate([log_eta + log_stick[:-1], log_stick[-1:]])

        # Scale unit simplex on log-scale
        w = w + log_sum

        # Reshape for output
        return (w.reshape(output_shape), laj)
|
bayinx/core/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from ._constraint import Constraint
|
|
2
2
|
from ._flow import Flow
|
|
3
|
-
from ._model import Model,
|
|
3
|
+
from ._model import Model, define
|
|
4
4
|
from ._optimization import optimize_model
|
|
5
5
|
from ._parameter import Parameter
|
|
6
6
|
from ._variational import Variational
|
|
@@ -9,7 +9,7 @@ __all__ = [
|
|
|
9
9
|
"Constraint",
|
|
10
10
|
"Flow",
|
|
11
11
|
"Model",
|
|
12
|
-
"
|
|
12
|
+
"define",
|
|
13
13
|
"optimize_model",
|
|
14
14
|
"Parameter",
|
|
15
15
|
"Variational",
|
bayinx/core/_flow.py
CHANGED
|
@@ -37,7 +37,6 @@ class Flow(eqx.Module):
|
|
|
37
37
|
|
|
38
38
|
# Default filter specification
|
|
39
39
|
@property
|
|
40
|
-
@eqx.filter_jit
|
|
41
40
|
def filter_spec(self):
|
|
42
41
|
"""
|
|
43
42
|
Generates a filter specification to subset relevant parameters for the flow.
|
|
@@ -54,7 +53,6 @@ class Flow(eqx.Module):
|
|
|
54
53
|
|
|
55
54
|
return filter_spec
|
|
56
55
|
|
|
57
|
-
@eqx.filter_jit
|
|
58
56
|
def constrain_params(self: Self):
|
|
59
57
|
"""
|
|
60
58
|
Constrain `params` to the appropriate domain.
|
|
@@ -69,7 +67,6 @@ class Flow(eqx.Module):
|
|
|
69
67
|
|
|
70
68
|
return t_params
|
|
71
69
|
|
|
72
|
-
@eqx.filter_jit
|
|
73
70
|
def transform_params(self: Self) -> Dict[str, Array]:
|
|
74
71
|
"""
|
|
75
72
|
Apply a custom transformation to `params` if needed.
|
bayinx/core/_model.py
CHANGED
|
@@ -1,19 +1,34 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
2
|
from dataclasses import field, fields
|
|
3
|
-
from typing import Any, Self, Tuple
|
|
3
|
+
from typing import Any, Dict, Optional, Self, Tuple
|
|
4
4
|
|
|
5
5
|
import equinox as eqx
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import jax.tree as jt
|
|
8
|
-
from jaxtyping import Scalar
|
|
8
|
+
from jaxtyping import PyTree, Scalar
|
|
9
9
|
|
|
10
10
|
from ._constraint import Constraint
|
|
11
11
|
from ._parameter import Parameter
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
def
|
|
15
|
-
|
|
16
|
-
|
|
14
|
+
def define(
    shape: Optional[Tuple[int, ...]] = None,
    init: Optional[PyTree] = None,
    constraint: Optional[Constraint] = None
):
    """
    Define a parameter.

    Builds a dataclass `field` whose metadata `Model.__new__` reads to
    auto-initialize the attribute as a `Parameter`.

    # Parameters
    - `shape`: Shape (a tuple) of a zero-initialized array parameter. Takes
      precedence over `init` when both are given.
    - `init`: Initial value (any PyTree) for the parameter, used when `shape` is
      not a tuple.
    - `constraint`: Optional `Constraint`, stored under the `"constraint"` key and
      applied by `Model.constrain_params`.

    # Returns
    A `dataclasses.field` carrying the metadata described above.

    # Raises
    - `TypeError`: If `shape` is not a tuple and `init` is `None`.
    """
    metadata: Dict = {}
    if constraint is not None:
        metadata["constraint"] = constraint

    if isinstance(shape, tuple):
        metadata["shape"] = shape
    # Check against None explicitly: every object (None included) is a valid
    # PyTree, so an `isinstance(init, PyTree)` test would never reach the error.
    elif init is not None:
        metadata["init"] = init
    else:
        raise TypeError("Neither 'shape' nor 'init' were given as proper arguments.")

    return field(metadata=metadata)
|
|
17
32
|
|
|
18
33
|
|
|
19
34
|
class Model(eqx.Module):
|
|
@@ -22,16 +37,32 @@ class Model(eqx.Module):
|
|
|
22
37
|
|
|
23
38
|
Annotate parameter attributes with `Parameter`.
|
|
24
39
|
|
|
25
|
-
Include constraints by setting them equal to `
|
|
40
|
+
Include constraints by setting them equal to `define(Constraint)`.
|
|
26
41
|
"""
|
|
27
42
|
|
|
43
|
+
def __new__(cls, *args, **kwargs):
    """
    Create a model instance with `define`-declared parameters pre-populated.

    For every dataclass field of `cls`, the field metadata (as written by
    `define`) selects the initial value:
    - `"shape"`: a `Parameter` wrapping a zero array of that shape.
    - `"init"`: a `Parameter` wrapping the given initial object.
    `"shape"` wins when both keys are present. Fields with neither key are
    not set here. Extra positional/keyword arguments are accepted but unused.
    """
    obj = super().__new__(cls)

    for fld in fields(cls):
        meta = fld.metadata

        if "shape" in meta:
            initial = Parameter(jnp.zeros(meta["shape"]))
        elif "init" in meta:
            initial = Parameter(meta["init"])
        else:
            # Not declared via `define`; nothing to pre-populate
            continue

        setattr(obj, fld.name, initial)

    return obj
|
|
56
|
+
|
|
57
|
+
def __init__(self):
    """
    Initialize the model.

    Parameters declared with `define` are already set in `__new__`, so there is
    nothing left to do here. `__init__` must return `None`: Python raises
    `TypeError` at instantiation if it returns anything else (the previous
    `return self` did exactly that).
    """
    pass
|
|
59
|
+
|
|
28
60
|
@abstractmethod
def eval(self, data: Any) -> Scalar:
    """
    Evaluate the model's log posterior density at its current parameters.

    Subclasses must implement this. The optimization and variational routines
    call `model.eval(data)` to obtain the scalar objective.

    # Parameters
    - `data`: Arbitrary data passed through by the caller.

    # Returns
    A scalar log density.
    """
    ...
|
|
31
63
|
|
|
32
64
|
# Default filter specification
|
|
33
65
|
@property
|
|
34
|
-
@eqx.filter_jit
|
|
35
66
|
def filter_spec(self) -> Self:
|
|
36
67
|
"""
|
|
37
68
|
Generates a filter specification to subset relevant parameters for the model.
|
|
@@ -49,12 +80,11 @@ class Model(eqx.Module):
|
|
|
49
80
|
filter_spec = eqx.tree_at(
|
|
50
81
|
lambda model: getattr(model, f.name),
|
|
51
82
|
filter_spec,
|
|
52
|
-
replace=attr.filter_spec
|
|
83
|
+
replace=attr.filter_spec
|
|
53
84
|
)
|
|
54
85
|
|
|
55
86
|
return filter_spec
|
|
56
87
|
|
|
57
|
-
@eqx.filter_jit
|
|
58
88
|
def constrain_params(self) -> Tuple[Self, Scalar]:
|
|
59
89
|
"""
|
|
60
90
|
Constrain parameters to the appropriate domain.
|
|
@@ -70,14 +100,14 @@ class Model(eqx.Module):
|
|
|
70
100
|
attr = getattr(self, f.name)
|
|
71
101
|
|
|
72
102
|
# Check if constrained parameter
|
|
73
|
-
if isinstance(attr, Parameter) and "constraint" in f.metadata:
|
|
103
|
+
if isinstance(attr, Parameter) and ("constraint" in f.metadata):
|
|
74
104
|
param = attr
|
|
75
105
|
constraint = f.metadata["constraint"]
|
|
76
106
|
|
|
77
107
|
# Apply constraint
|
|
78
108
|
param, laj = constraint.constrain(param)
|
|
79
109
|
|
|
80
|
-
# Update parameters for constrained model
|
|
110
|
+
# Update parameters for constrained model at same node
|
|
81
111
|
constrained = eqx.tree_at(
|
|
82
112
|
lambda model: getattr(model, f.name), constrained, replace=param
|
|
83
113
|
)
|
|
@@ -87,7 +117,6 @@ class Model(eqx.Module):
|
|
|
87
117
|
|
|
88
118
|
return constrained, target
|
|
89
119
|
|
|
90
|
-
@eqx.filter_jit
|
|
91
120
|
def transform_params(self) -> Tuple[Self, Scalar]:
|
|
92
121
|
"""
|
|
93
122
|
Apply a custom transformation to parameters if needed(defaults to constrained parameters).
|
bayinx/core/_optimization.py
CHANGED
|
@@ -10,6 +10,8 @@ from optax import GradientTransformation, OptState, Schedule
|
|
|
10
10
|
from ._model import Model
|
|
11
11
|
|
|
12
12
|
M = TypeVar("M", bound=Model)
|
|
13
|
+
|
|
14
|
+
|
|
13
15
|
@eqx.filter_jit
|
|
14
16
|
def optimize_model(
|
|
15
17
|
model: M,
|
|
@@ -39,6 +41,7 @@ def optimize_model(
|
|
|
39
41
|
|
|
40
42
|
# Evaluate posterior
|
|
41
43
|
return model.eval(data)
|
|
44
|
+
|
|
42
45
|
eval_grad: Callable[[M], M] = eqx.filter_jit(eqx.filter_grad(eval))
|
|
43
46
|
|
|
44
47
|
# Construct scheduler
|
bayinx/core/_parameter.py
CHANGED
|
@@ -5,6 +5,8 @@ import jax.tree as jt
|
|
|
5
5
|
from jaxtyping import PyTree
|
|
6
6
|
|
|
7
7
|
T = TypeVar("T", bound=PyTree)
|
|
8
|
+
|
|
9
|
+
|
|
8
10
|
class Parameter(eqx.Module, Generic[T]):
|
|
9
11
|
"""
|
|
10
12
|
A container for a parameter of a `Model`.
|
|
@@ -26,19 +28,18 @@ class Parameter(eqx.Module, Generic[T]):
|
|
|
26
28
|
|
|
27
29
|
# Default filter specification
|
|
28
30
|
@property
def filter_spec(self) -> Self:
    """
    Generates a filter specification to filter for dynamic parameters.

    Every leaf of the `Parameter` is marked `False`, except the leaves of
    `vals` that are inexact (floating-point) array-likes, which are marked `True`.
    """
    # Flag which leaves of `vals` are dynamic (inexact array-likes)
    dynamic_vals = jt.map(eqx.is_inexact_array_like, self.vals)

    # Start from an all-False specification, then overwrite the `vals` subtree
    all_static: Self = jt.map(lambda _: False, self)
    return eqx.tree_at(lambda p: p.vals, all_static, replace=dynamic_vals)
|
bayinx/core/_variational.py
CHANGED
|
@@ -22,11 +22,11 @@ class Variational(eqx.Module, Generic[M]):
|
|
|
22
22
|
|
|
23
23
|
# Attributes
|
|
24
24
|
- `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
|
|
25
|
-
- `
|
|
25
|
+
- `_static`: The static component of a partitioned `Model` used to initialize the `Variational` object.
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
28
|
_unflatten: Callable[[Array], M]
|
|
29
|
-
|
|
29
|
+
_static: M
|
|
30
30
|
|
|
31
31
|
@abstractmethod
|
|
32
32
|
def filter_spec(self):
|
|
@@ -69,7 +69,7 @@ class Variational(eqx.Module, Generic[M]):
|
|
|
69
69
|
model: M = self._unflatten(draw)
|
|
70
70
|
|
|
71
71
|
# Combine with constraints
|
|
72
|
-
model: M = eqx.combine(model, self.
|
|
72
|
+
model: M = eqx.combine(model, self._static)
|
|
73
73
|
|
|
74
74
|
return model
|
|
75
75
|
|
|
@@ -89,7 +89,6 @@ class Variational(eqx.Module, Generic[M]):
|
|
|
89
89
|
# Evaluate posterior density
|
|
90
90
|
return model.eval(data)
|
|
91
91
|
|
|
92
|
-
# TODO: get rid of this and put it all in each vari's methods, forgot abt discrete parameters :V
|
|
93
92
|
@eqx.filter_jit
|
|
94
93
|
def fit(
|
|
95
94
|
self,
|
|
@@ -116,11 +115,9 @@ class Variational(eqx.Module, Generic[M]):
|
|
|
116
115
|
dyn, static = eqx.partition(self, self.filter_spec)
|
|
117
116
|
|
|
118
117
|
# Construct scheduler
|
|
119
|
-
schedule: Schedule = opx.
|
|
120
|
-
init_value=
|
|
121
|
-
|
|
122
|
-
warmup_steps=int(max_iters / 10),
|
|
123
|
-
decay_steps=max_iters - int(max_iters / 10),
|
|
118
|
+
schedule: Schedule = opx.cosine_decay_schedule(
|
|
119
|
+
init_value=learning_rate,
|
|
120
|
+
decay_steps=max_iters,
|
|
124
121
|
)
|
|
125
122
|
|
|
126
123
|
# Initialize optimizer
|
|
@@ -175,7 +172,7 @@ class Variational(eqx.Module, Generic[M]):
|
|
|
175
172
|
return eqx.combine(dyn, static)
|
|
176
173
|
|
|
177
174
|
@eqx.filter_jit
|
|
178
|
-
def
|
|
175
|
+
def _posterior_predictive(
|
|
179
176
|
self,
|
|
180
177
|
func: Callable[[M, Any], Array],
|
|
181
178
|
n: int,
|
|
File without changes
|
|
@@ -111,6 +111,11 @@ def sample(
|
|
|
111
111
|
|
|
112
112
|
# Construct draws
|
|
113
113
|
draws = jr.uniform(key, shape)
|
|
114
|
-
draws = mu + sigma * ndtri(
|
|
114
|
+
draws = mu + sigma * ndtri(
|
|
115
|
+
normal.cdf(-mu / sigma, 0.0, 1.0) + draws * normal.cdf(mu / sigma, 0.0, 1.0)
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
# Censor draws
|
|
119
|
+
draws.at[censor <= draws].set(censor)
|
|
115
120
|
|
|
116
121
|
return draws
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax.scipy.special import gammaln
|
|
3
|
+
from jaxtyping import Array, ArrayLike, Float, UInt
|
|
4
|
+
|
|
5
|
+
# Numeric literal for pi. NOTE(review): not referenced anywhere in this module.
__PI = 3.141592653589793
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def __binom(x, y):
    """
    Helper returning the binomial coefficient C(x, y).

    It is evaluated through log-gamma functions, so non-integer (and array)
    arguments are accepted.
    """
    log_coef = gammaln(x + 1) - gammaln(y + 1) - gammaln(x - y + 1)
    return jnp.exp(log_coef)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def prob(
    x: UInt[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    phi: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.

    Computed by exponentiating `logprob`.

    # Parameters
    - `x`: Where to evaluate the PMF.
    - `mu`: The mean.
    - `phi`: The inverse overdispersion.

    # Returns
    The PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
    """
    # Evaluate on the log scale, then map back
    log_evals = logprob(jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi))
    return jnp.exp(log_evals)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def logprob(
    x: UInt[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    phi: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log-transformed probability mass function (PMF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the log PMF.
    - `mu`: The mean.
    - `phi`: The inverse overdispersion.

    # Returns
    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `phi`.
    Points with `x < 0` evaluate to `-inf`.
    """
    from jax.scipy.special import xlogy

    # Cast to Array
    x, mu, phi = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(phi)

    # Replace out-of-support points by 0 before evaluating the formula, so that
    # `gammaln` is never evaluated at a pole. Otherwise the discarded branch of
    # the final `jnp.where` could still inject NaNs into gradients.
    in_support = x >= 0
    x_safe = jnp.where(in_support, x, 0)

    log_mu_phi = jnp.log(mu + phi)

    # log C(x + phi - 1, x) + x log(mu / (mu + phi)) + phi log(phi / (mu + phi)).
    # `xlogy` makes the x * log(mu) term exactly 0 when x = mu = 0 (not NaN).
    log_pmf = (
        gammaln(x_safe + phi)
        - gammaln(x_safe + 1)
        - gammaln(phi)
        + xlogy(x_safe, mu)
        - x_safe * log_mu_phi
        + phi * (jnp.log(phi) - log_mu_phi)
    )

    # Evaluate log PMF
    evals: Array = jnp.where(in_support, log_pmf, -jnp.inf)

    return evals
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def cdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `mu`: The mean.
    - `sigma`: The inverse overdispersion (the same quantity as `phi` in `logprob`).

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    from jax.scipy.special import betainc

    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # The CDF only changes at integers: P(X <= x) = P(X <= floor(x)).
    k = jnp.floor(x)
    in_support = k >= 0
    # Below the support, substitute a valid value so `betainc` gets b > 0
    k_safe = jnp.where(in_support, k, 0.0)

    # P(X <= k) = I_p(phi, k + 1), with p = phi / (mu + phi)
    evals = betainc(sigma, k_safe + 1.0, sigma / (mu + sigma))

    return jnp.where(in_support, evals, 0.0)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def logcdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log-transformed cumulative distribution function (CDF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF.
    - `mu`: The mean.
    - `sigma`: The inverse overdispersion (the same quantity as `phi` in `logprob`).

    # Returns
    The log CDF evaluated at `x` (`-inf` for `x < 0`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    from jax.scipy.special import betainc

    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # The CDF only changes at integers: P(X <= x) = P(X <= floor(x)).
    k = jnp.floor(x)
    in_support = k >= 0
    # Below the support, substitute a valid value so `betainc` gets b > 0
    k_safe = jnp.where(in_support, k, 0.0)

    # log P(X <= k) = log I_p(phi, k + 1), with p = phi / (mu + phi)
    evals = jnp.log(betainc(sigma, k_safe + 1.0, sigma / (mu + sigma)))

    return jnp.where(in_support, evals, -jnp.inf)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def ccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The complementary cumulative distribution function (1 - CDF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the complementary CDF.
    - `mu`: The mean.
    - `sigma`: The inverse overdispersion (the same quantity as `phi` in `logprob`).

    # Returns
    The complementary CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    from jax.scipy.special import betainc

    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # The CDF only changes at integers: P(X > x) = P(X > floor(x)).
    k = jnp.floor(x)
    in_support = k >= 0
    # Below the support, substitute a valid value so `betainc` gets a > 0
    k_safe = jnp.where(in_support, k, 0.0)

    # P(X > k) = 1 - I_p(phi, k + 1) = I_{1-p}(k + 1, phi), with 1 - p = mu / (mu + phi).
    # Evaluating the complement directly avoids cancellation in 1 - CDF.
    evals = betainc(k_safe + 1.0, sigma, mu / (mu + sigma))

    return jnp.where(in_support, evals, 1.0)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def logccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log-transformed complementary cumulative distribution function (1 - CDF) for a (mean-inverse overdispersion parameterized) Negative Binomial distribution.

    # Parameters
    - `x`: Where to evaluate the log complementary CDF.
    - `mu`: The mean.
    - `sigma`: The inverse overdispersion (the same quantity as `phi` in `logprob`).

    # Returns
    The log complementary CDF evaluated at `x` (`0` for `x < 0`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    from jax.scipy.special import betainc

    # Cast to Array
    x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)

    # The CDF only changes at integers: P(X > x) = P(X > floor(x)).
    k = jnp.floor(x)
    in_support = k >= 0
    # Below the support, substitute a valid value so `betainc` gets a > 0
    k_safe = jnp.where(in_support, k, 0.0)

    # log P(X > k) = log I_{1-p}(k + 1, phi), with 1 - p = mu / (mu + phi)
    evals = jnp.log(betainc(k_safe + 1.0, sigma, mu / (mu + sigma)))

    return jnp.where(in_support, evals, 0.0)
|
bayinx/dists/posnormal.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
-
|
|
2
|
+
import jax.random as jr
|
|
3
|
+
from jax.scipy.special import ndtri
|
|
4
|
+
from jaxtyping import Array, ArrayLike, Float, Key
|
|
3
5
|
|
|
4
6
|
from bayinx.dists import normal
|
|
5
7
|
|
|
@@ -251,3 +253,38 @@ def logccdf(
|
|
|
251
253
|
evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
|
|
252
254
|
|
|
253
255
|
return evals
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def sample(
    n: int,
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    key: Key = jr.PRNGKey(0),
) -> Float[Array, "..."]:
    """
    Sample from a positive Normal distribution.

    Draws come from inverse-CDF sampling of a Normal(`mu`, `sigma`) truncated
    to the positive half-line.

    # Parameters
    - `n`: Number of draws to sample per-parameter.
    - `mu`: The mean.
    - `sigma`: The standard deviation.
    - `key`: PRNG key. The default `jr.PRNGKey(0)` is created once at import
      time, so repeated calls without a key return identical draws.

    # Returns
    Draws from a positive Normal distribution. The output will have the shape of (n,) + the broadcasted shapes of `mu` and `sigma`.
    """
    # Cast to Array
    mu, sigma = jnp.asarray(mu), jnp.asarray(sigma)

    # Derive shape
    shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape)

    # Mass of the untruncated Normal above zero: Phi(mu / sigma)
    upper_mass = normal.cdf(mu / sigma, 0.0, 1.0)

    # jr.uniform draws from [0, 1); flip to (0, 1] so the ndtri argument below
    # is never exactly 0 (which would give an infinite draw).
    v = 1.0 - jr.uniform(key, shape)

    # Invert from the upper tail: Z = -ndtri(v * Phi(a)), with a = mu / sigma,
    # lies in [-a, inf). The argument is then a small probability whenever the
    # truncation is severe (mu << 0). The former form,
    # ndtri(Phi(-a) + u * Phi(a)), rounded to ndtri(1) = inf in that regime.
    draws = mu - sigma * ndtri(v * upper_mass)

    # Guard the boundary against rounding just below zero
    draws = jnp.maximum(draws, 0.0)

    return draws
|
bayinx/dists/uniform.py
CHANGED
|
@@ -83,8 +83,12 @@ def ulogprob(
|
|
|
83
83
|
|
|
84
84
|
return jnp.zeros(jnp.broadcast_shapes(x.shape, lb.shape, ub.shape))
|
|
85
85
|
|
|
86
|
+
|
|
86
87
|
def sample(
|
|
87
|
-
n: int,
|
|
88
|
+
n: int,
|
|
89
|
+
lb: Float[ArrayLike, "..."],
|
|
90
|
+
ub: Float[ArrayLike, "..."],
|
|
91
|
+
key: Key = jr.PRNGKey(0),
|
|
88
92
|
) -> Float[Array, "..."]:
|
|
89
93
|
"""
|
|
90
94
|
Sample from a Uniform distribution.
|
|
@@ -104,6 +108,6 @@ def sample(
|
|
|
104
108
|
shape = (n,) + jnp.broadcast_shapes(lb.shape, ub.shape)
|
|
105
109
|
|
|
106
110
|
# Construct draws
|
|
107
|
-
draws = jr.uniform(key, shape, minval
|
|
111
|
+
draws = jr.uniform(key, shape, minval=lb, maxval=ub)
|
|
108
112
|
|
|
109
113
|
return draws
|
bayinx/mhx/vi/flows/radial.py
CHANGED
|
@@ -37,8 +37,8 @@ class Radial(Flow):
|
|
|
37
37
|
}
|
|
38
38
|
self.constraints = {"beta": jnp.exp}
|
|
39
39
|
|
|
40
|
-
@partial(jax.vmap, in_axes=(None, 0))
|
|
41
40
|
@eqx.filter_jit
|
|
41
|
+
@partial(jax.vmap, in_axes=(None, 0))
|
|
42
42
|
def forward(self, draws: Array) -> Array:
|
|
43
43
|
"""
|
|
44
44
|
Applies the forward radial transformation for each draw.
|
bayinx/mhx/vi/meanfield.py
CHANGED
|
@@ -34,7 +34,7 @@ class MeanField(Variational, Generic[M]):
|
|
|
34
34
|
- `init_log_std`: The initial log-transformed standard deviation of the Gaussian approximation.
|
|
35
35
|
"""
|
|
36
36
|
# Partition model
|
|
37
|
-
params, self.
|
|
37
|
+
params, self._static = eqx.partition(model, model.filter_spec)
|
|
38
38
|
|
|
39
39
|
# Flatten params component
|
|
40
40
|
params, self._unflatten = ravel_pytree(params)
|
|
@@ -44,7 +44,6 @@ class MeanField(Variational, Generic[M]):
|
|
|
44
44
|
self.log_std = jnp.full(params.size, init_log_std, params.dtype)
|
|
45
45
|
|
|
46
46
|
@property
|
|
47
|
-
@eqx.filter_jit
|
|
48
47
|
def filter_spec(self):
|
|
49
48
|
# Generate empty specification
|
|
50
49
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
|
@@ -67,8 +66,7 @@ class MeanField(Variational, Generic[M]):
|
|
|
67
66
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
|
68
67
|
# Sample variational draws
|
|
69
68
|
draws: Array = (
|
|
70
|
-
jr.normal(key=key, shape=(n, self.mean.size))
|
|
71
|
-
* jnp.exp(self.log_std)
|
|
69
|
+
jr.normal(key=key, shape=(n, self.mean.size)) * jnp.exp(self.log_std)
|
|
72
70
|
+ self.mean
|
|
73
71
|
)
|
|
74
72
|
|
|
@@ -108,10 +106,9 @@ class MeanField(Variational, Generic[M]):
|
|
|
108
106
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
|
109
107
|
dyn, static = eqx.partition(self, self.filter_spec)
|
|
110
108
|
|
|
111
|
-
@eqx.filter_grad
|
|
112
109
|
@eqx.filter_jit
|
|
110
|
+
@eqx.filter_grad
|
|
113
111
|
def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
|
|
114
|
-
# Combine
|
|
115
112
|
vari = eqx.combine(dyn, static)
|
|
116
113
|
|
|
117
114
|
# Sample draws from variational distribution
|
|
@@ -31,11 +31,11 @@ class NormalizingFlow(Variational, Generic[M]):
|
|
|
31
31
|
|
|
32
32
|
# Parameters
|
|
33
33
|
- `base`: The base variational distribution.
|
|
34
|
-
- `flows`: A list of
|
|
34
|
+
- `flows`: A list of flows.
|
|
35
35
|
- `model`: A probabilistic `Model` object.
|
|
36
36
|
"""
|
|
37
37
|
# Partition model
|
|
38
|
-
params, self.
|
|
38
|
+
params, self._static = eqx.partition(model, model.filter_spec)
|
|
39
39
|
|
|
40
40
|
# Flatten params component
|
|
41
41
|
_, self._unflatten = jfu.ravel_pytree(params)
|
|
@@ -44,7 +44,6 @@ class NormalizingFlow(Variational, Generic[M]):
|
|
|
44
44
|
self.flows = flows
|
|
45
45
|
|
|
46
46
|
@property
|
|
47
|
-
@eqx.filter_jit
|
|
48
47
|
def filter_spec(self):
|
|
49
48
|
# Generate empty specification
|
|
50
49
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
|
@@ -78,7 +77,7 @@ class NormalizingFlow(Variational, Generic[M]):
|
|
|
78
77
|
variational_evals: Array = self.base.eval(draws)
|
|
79
78
|
|
|
80
79
|
for map in self.flows:
|
|
81
|
-
#
|
|
80
|
+
# Apply transformation
|
|
82
81
|
draws, laj = map.adjust_density(draws)
|
|
83
82
|
|
|
84
83
|
# Adjust variational density
|
|
@@ -103,7 +102,7 @@ class NormalizingFlow(Variational, Generic[M]):
|
|
|
103
102
|
variational_evals: Array = self.base.eval(draws)
|
|
104
103
|
|
|
105
104
|
for map in self.flows:
|
|
106
|
-
#
|
|
105
|
+
# Apply transformation
|
|
107
106
|
draws, laj = map.adjust_density(draws)
|
|
108
107
|
|
|
109
108
|
# Adjust variational density
|
bayinx/mhx/vi/standard.py
CHANGED
|
@@ -27,7 +27,7 @@ class Standard(Variational[M]):
|
|
|
27
27
|
- `model`: A probabilistic `Model` object.
|
|
28
28
|
"""
|
|
29
29
|
# Partition model
|
|
30
|
-
params, self.
|
|
30
|
+
params, self._static = eqx.partition(model, model.filter_spec)
|
|
31
31
|
|
|
32
32
|
# Flatten params component
|
|
33
33
|
params, self._unflatten = ravel_pytree(params)
|
|
@@ -35,6 +35,7 @@ class Standard(Variational[M]):
|
|
|
35
35
|
# Store dimension of parameter space
|
|
36
36
|
self.dim = jnp.size(params)
|
|
37
37
|
|
|
38
|
+
|
|
38
39
|
@eqx.filter_jit
|
|
39
40
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
|
40
41
|
# Sample variational draws
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: bayinx
|
|
3
|
+
Version: 0.4.1
|
|
4
|
+
Summary: Bayesian Inference with JAX
|
|
5
|
+
Author: Todd McCready
|
|
6
|
+
Maintainer: Todd McCready
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
|
+
Requires-Dist: equinox>=0.11.12
|
|
10
|
+
Requires-Dist: jax>=0.4.38
|
|
11
|
+
Requires-Dist: jaxtyping>=0.2.36
|
|
12
|
+
Requires-Dist: optax>=0.2.4
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
# Bayinx: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
|
16
|
+
|
|
17
|
+
The original aim of this project was to build a PPL in Python that is similar in feel to Stan or Nimble (where there is a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to [baycian](https://github.com/toddmccready/baycian) for the foreseeable future.
|
|
18
|
+
|
|
19
|
+
Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to [Rust macros](https://doc.rust-lang.org/rust-by-example/macros/dsl.html); I can define syntax similar to Stan and parse it into valid Rust code. Additionally, bayinx is already relatively functional (give or take some clean-up and documentation), and it offers enough for one of my other projects: [disize](https://github.com/toddmccready/disize)! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filtering parameters for gradient calculations, etc.
|
|
20
|
+
|
|
21
|
+
Instead, this project is narrowing its focus to implementing much of Stan's functionality (restricted to continuously parameterized models, point estimation + VI + MCMC, etc.) without most of the nice syntax, at least for versions `0.4.#`. Therefore, users will work with `target` directly and return the density as shown below:
|
|
22
|
+
|
|
23
|
+
```py
|
|
24
|
+
class NormalDist(Model):
|
|
25
|
+
x: Parameter[Array] = define(shape = (2,))
|
|
26
|
+
|
|
27
|
+
def eval(self, data: Dict[str, Array]):
|
|
28
|
+
# Constrain parameters
|
|
29
|
+
self, target = self.constrain_params() # this does nothing for the current model
|
|
30
|
+
|
|
31
|
+
# Evaluate x ~ Normal(10.0, 1.0)
|
|
32
|
+
target += normal.logprob(self.x(), 10.0, 1.0).sum()
|
|
33
|
+
|
|
34
|
+
return target
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
I have ideas for using a context manager and implementing `Node`: `Observed`/`Stochastic` classes that will try and replicate what `baycian` is trying to do, but that is for the future and versions `0.4.#` will retain the functionality needed for disize.
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# TODO
|
|
41
|
+
- For optimization and variational methods, offer a way for users to define custom stopping conditions (e.g., stop once a single parameter has converged).
|
|
42
|
+
- Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
|
|
43
|
+
- Low-rank affine flow?
|
|
44
|
+
- https://arxiv.org/pdf/1803.05649 implement sylvester flows.
|
|
45
|
+
- Learn how to generate documentation.
|
|
46
|
+
- Figure out how to implement `transform_params` for flows without a performance loss; I have noticed some weird behaviour when adding constraints.
|
|
47
|
+
- Look into adaptively tuning ADAM hyperparameters for VI.
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
bayinx/__init__.py,sha256=8etrxEtEGEzSDmKsW0TB4XoUGLiMPt9wpwNR8CGe1gU,93
|
|
2
|
+
bayinx/constraints.py,sha256=2ufHsXR-_bWKR4WKKuR-OTjj3XCc4TkSeHVGWYadwCg,4387
|
|
3
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
+
bayinx/core/__init__.py,sha256=samkrHp2zYyj8n37k-06tlaVrSqbtcgoa1LO0btAEHc,338
|
|
5
|
+
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
|
6
|
+
bayinx/core/_flow.py,sha256=V7uVfo2kBkyD4lWKBdSFFSmla68E_Hqv-uu9akuWBio,2282
|
|
7
|
+
bayinx/core/_model.py,sha256=9uRVG5Zwrt9GHCY1-ULXw4KCZWHBpsG8AirJplDMT78,3848
|
|
8
|
+
bayinx/core/_optimization.py,sha256=Ehp8UqtN-IYrEeX178TFUjCsJEBHxDnqxvgIco15zYU,2654
|
|
9
|
+
bayinx/core/_parameter.py,sha256=O2s0-WxxqWOqwjuqWBh-JU-rRSBKpT_Fzy4ZxXr3cEc,1080
|
|
10
|
+
bayinx/core/_variational.py,sha256=ddjtdq6z8tCJ5RZ4crveiN1HuPILdejYfwcAiZDLNZ4,6164
|
|
11
|
+
bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
|
|
12
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
|
13
|
+
bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
|
|
14
|
+
bayinx/dists/negbinom3.py,sha256=u_USHQHxXmdS6hDCW2xmDpcqlxf17SJLaeptpxAI8TQ,2912
|
|
15
|
+
bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
|
|
16
|
+
bayinx/dists/posnormal.py,sha256=P4ShZPeqw2dr1mnhgaMku8JLEmCeQOe3jLb-BQbqs9o,8087
|
|
17
|
+
bayinx/dists/uniform.py,sha256=FGXEIrvq4UXIhZ3mz4EDmbkpnZxUa2MkyeqEUu0CGJ4,3470
|
|
18
|
+
bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
|
|
19
|
+
bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
|
20
|
+
bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
|
|
21
|
+
bayinx/dists/censored/negbinom3/r.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
+
bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
|
23
|
+
bayinx/dists/censored/posnormal/r.py,sha256=nefEBJvjgekPL8L7L7-UXV5bc9cIWDlcDBwWWrA7YcM,3536
|
|
24
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
25
|
+
bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
|
+
bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
|
|
27
|
+
bayinx/mhx/vi/meanfield.py,sha256=GLX6fzXr1v_PEHNyXCcBgTT5_7tGSSEkrTkwCNujplI,3858
|
|
28
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=xSRxJ6hmrT6Fnx-uSIJ7mqAO5JM7bmbJtj2uljcKERs,4681
|
|
29
|
+
bayinx/mhx/vi/standard.py,sha256=s-Kvw37Y_KzlkcbeS6eVppg6DgPCclZQEcooGgbg3SU,1574
|
|
30
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
|
31
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=7KXuukwzVtMRIa8bSK_4pjnnP-lLIzVJBCAuKVydVgE,1925
|
|
32
|
+
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
|
33
|
+
bayinx/mhx/vi/flows/radial.py,sha256=AyaqLJCwn871L6E8lBCU4Y8zZBF9UYZu6KIhzV6Z6wo,2503
|
|
34
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
|
35
|
+
bayinx-0.4.1.dist-info/METADATA,sha256=7Zw-9hVqUVxj3ncyGBfn72FQzTDomdDaXXH2hOsJM60,2989
|
|
36
|
+
bayinx-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
37
|
+
bayinx-0.4.1.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
|
38
|
+
bayinx-0.4.1.dist-info/RECORD,,
|
bayinx/constraints/__init__.py
DELETED
bayinx/constraints/lower.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
from typing import Tuple
|
|
2
|
-
|
|
3
|
-
import equinox as eqx
|
|
4
|
-
import jax.numpy as jnp
|
|
5
|
-
import jax.tree as jt
|
|
6
|
-
from jaxtyping import PyTree, Scalar, ScalarLike
|
|
7
|
-
|
|
8
|
-
from bayinx.core import Constraint, Parameter
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class Lower(Constraint):
|
|
12
|
-
"""
|
|
13
|
-
Enforces a lower bound on the parameter.
|
|
14
|
-
"""
|
|
15
|
-
|
|
16
|
-
lb: Scalar
|
|
17
|
-
|
|
18
|
-
def __init__(self, lb: ScalarLike):
|
|
19
|
-
self.lb = jnp.array(lb)
|
|
20
|
-
|
|
21
|
-
@eqx.filter_jit
|
|
22
|
-
def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
|
|
23
|
-
"""
|
|
24
|
-
Enforces a lower bound on the parameter and adjusts the posterior density.
|
|
25
|
-
|
|
26
|
-
# Parameters
|
|
27
|
-
- `x`: The unconstrained `Parameter`.
|
|
28
|
-
|
|
29
|
-
# Parameters
|
|
30
|
-
A tuple containing:
|
|
31
|
-
- A modified `Parameter` with relevant leaves satisfying the constraint.
|
|
32
|
-
- A scalar Array representing the log-absolute-Jacobian of the transformation.
|
|
33
|
-
"""
|
|
34
|
-
# Extract relevant filter specification
|
|
35
|
-
filter_spec = x.filter_spec
|
|
36
|
-
|
|
37
|
-
# Extract relevant parameters(all Array)
|
|
38
|
-
dyn_params, static_params = eqx.partition(x, filter_spec)
|
|
39
|
-
|
|
40
|
-
# Compute density adjustment
|
|
41
|
-
laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
|
|
42
|
-
laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
|
|
43
|
-
|
|
44
|
-
# Compute transformation
|
|
45
|
-
dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
|
|
46
|
-
|
|
47
|
-
# Combine into full parameter object
|
|
48
|
-
x = eqx.combine(dyn_params, static_params)
|
|
49
|
-
|
|
50
|
-
return x, laj
|
bayinx-0.3.19.dist-info/METADATA
DELETED
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: bayinx
|
|
3
|
-
Version: 0.3.19
|
|
4
|
-
Summary: Bayesian Inference with JAX
|
|
5
|
-
License-File: LICENSE
|
|
6
|
-
Requires-Python: >=3.12
|
|
7
|
-
Requires-Dist: equinox>=0.11.12
|
|
8
|
-
Requires-Dist: jax>=0.4.38
|
|
9
|
-
Requires-Dist: jaxtyping>=0.2.36
|
|
10
|
-
Requires-Dist: optax>=0.2.4
|
|
11
|
-
Description-Content-Type: text/markdown
|
|
12
|
-
|
|
13
|
-
# `Bayinx`: <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
|
14
|
-
|
|
15
|
-
The end goal of this project is to build a Bayesian inference library that is similar in feel to `Stan` (where you can define a probabilistic model with syntax similar to how you would write it out on a chalkboard) but allows for arbitrary models (e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of estimation methods (point optimization, variational methods, MCMC) while keeping everything performant (hence using `JAX`).
|
|
16
|
-
|
|
17
|
-
In the short-term, I'm going to focus on:
|
|
18
|
-
1) Implementing as much machinery as I feel is enough.
|
|
19
|
-
2) Figuring out how to design the `Model` superclass to have something like the `transformed pars {}` block, but one that unifies transformations and constraints.
|
|
20
|
-
3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
|
|
21
|
-
|
|
22
|
-
In the long-term, I'm going to focus on:
|
|
23
|
-
1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (as `Stan` does).
|
|
24
|
-
2) How to make working with the posterior as easy as possible.
|
|
25
|
-
- That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{D}, \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{D}, \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
|
|
26
|
-
|
|
27
|
-
Although this is somewhat separate from the goals of the project, if this pans out the way I'm envisioning it, I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
|
|
28
|
-
|
|
29
|
-
Additionally, when I get around to it, I'd like the package documentation to also include theoretical and implementation details for all implemented machinery (with overthinking boxes, because I do like that design from McElreath's book).
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
# TODO
|
|
33
|
-
- Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
|
|
34
|
-
- Look into adaptively tuning ADAM hyperparameters.
|
|
35
|
-
- Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
|
|
36
|
-
- Low-rank affine flow?
|
|
37
|
-
- https://arxiv.org/pdf/1803.05649 implement sylvester flows.
|
|
38
|
-
- Learn how to generate documentation lol.
|
|
39
|
-
- Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
|
bayinx-0.3.19.dist-info/RECORD
DELETED
|
@@ -1,37 +0,0 @@
|
|
|
1
|
-
bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
|
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
bayinx/constraints/__init__.py,sha256=027WJxRLkybXZkmusfvR6iZayY2pDid7Tw6TTTeA6ko,64
|
|
4
|
-
bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
|
|
5
|
-
bayinx/core/__init__.py,sha256=Qmy0EjzqqKwI9F8rjmC9j6J8hiDw6A54yOck2WuQJkY,344
|
|
6
|
-
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
|
7
|
-
bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
|
8
|
-
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
|
9
|
-
bayinx/core/_optimization.py,sha256=mmeVUqfFARz8F7q4LRl-uEwVWzekmzh-9o7PnuvsHZk,2651
|
|
10
|
-
bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
|
|
11
|
-
bayinx/core/_variational.py,sha256=ljhLi9vbw8wh-z0Eisf0C08IoKlgwpx0VONh_ES-HmI,6384
|
|
12
|
-
bayinx/dists/__init__.py,sha256=BIrypqMnTLWK3a_zw8fYKMyuEMxP_qGsLfLeScias0o,118
|
|
13
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
|
14
|
-
bayinx/dists/gamma2.py,sha256=HtB60LUQdPj4yDAHme2jsHNmLfrAKWsSZnDYkxAGaOI,1548
|
|
15
|
-
bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
|
|
16
|
-
bayinx/dists/posnormal.py,sha256=cOLCdd39DX3v8DD-seSIKNk4OfdNfaYaLzpCh_xBGyw,7150
|
|
17
|
-
bayinx/dists/uniform.py,sha256=2ZQxEfAX5TFgSPuQ8joFDuFbd_NfmQ1GvmGGjusqvNQ,3461
|
|
18
|
-
bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
|
|
19
|
-
bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
|
20
|
-
bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
|
|
21
|
-
bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
|
22
|
-
bayinx/dists/censored/posnormal/r.py,sha256=wMDt2Am1TD376ms8B-o6PFCJZXmUJd2-aBC-t9kidH4,3456
|
|
23
|
-
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
24
|
-
bayinx/mhx/opt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
-
bayinx/mhx/vi/__init__.py,sha256=3T1dEpiiRge4tW-vpS0xBob_RbO1iVFnL3fVCRUawCM,205
|
|
26
|
-
bayinx/mhx/vi/meanfield.py,sha256=iX4AeDG9jrLZd6d9NimuJ3O5zaoBXsD03JbgPgxVrfY,3917
|
|
27
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=vzLu5H1G1-pBqhgHWmIZkUTyPE1DxC9vBwpiZeIyu1I,4712
|
|
28
|
-
bayinx/mhx/vi/standard.py,sha256=LYgglaGQMGmXpzFR4eMJnXkl2PhBeggbXMvO5zJpf2c,1578
|
|
29
|
-
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
|
30
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
|
31
|
-
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
|
32
|
-
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
|
33
|
-
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
|
34
|
-
bayinx-0.3.19.dist-info/METADATA,sha256=V1qTM781r9t4cDGSRkgvn4PQAoxtsD9cFrMj0YVhUHo,3087
|
|
35
|
-
bayinx-0.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
36
|
-
bayinx-0.3.19.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
|
37
|
-
bayinx-0.3.19.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|