bayinx 0.2.6__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/flow.py +35 -23
- bayinx/core/model.py +19 -0
- bayinx/core/variational.py +17 -10
- bayinx/mhx/__init__.py +1 -0
- bayinx/mhx/vi/__init__.py +3 -0
- bayinx/mhx/vi/flows/__init__.py +3 -0
- bayinx/{machinery/variational → mhx/vi}/flows/affine.py +1 -3
- bayinx/{machinery/variational → mhx/vi}/flows/planar.py +1 -3
- bayinx/{machinery/variational → mhx/vi}/flows/radial.py +3 -4
- bayinx/mhx/vi/flows/sylvester.py +19 -0
- bayinx/{machinery/variational → mhx/vi}/meanfield.py +4 -4
- bayinx/{machinery/variational → mhx/vi}/normalizing_flow.py +12 -14
- bayinx/{machinery/variational → mhx/vi}/standard.py +6 -6
- {bayinx-0.2.6.dist-info → bayinx-0.2.9.dist-info}/METADATA +2 -4
- bayinx-0.2.9.dist-info/RECORD +26 -0
- bayinx/machinery/__init__.py +0 -0
- bayinx/machinery/variational/__init__.py +0 -5
- bayinx/machinery/variational/flows/__init__.py +0 -3
- bayinx/machinery/variational/flows/sylvester.py +0 -76
- bayinx-0.2.6.dist-info/RECORD +0 -26
- {bayinx-0.2.6.dist-info → bayinx-0.2.9.dist-info}/WHEEL +0 -0
bayinx/core/flow.py
CHANGED
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
|
2
2
|
from typing import Callable, Dict, Self, Tuple
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
|
+
import jax.tree_util as jtu
|
5
6
|
from jaxtyping import Array, Float
|
6
7
|
|
7
8
|
from bayinx.core.utils import __MyMeta
|
@@ -36,33 +37,44 @@ class Flow(eqx.Module, metaclass=__MyMeta):
|
|
36
37
|
"""
|
37
38
|
pass
|
38
39
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
40
|
+
# Default filter specification
|
41
|
+
def filter_spec(self):
|
42
|
+
"""
|
43
|
+
Generates a filter specification to subset relevant parameters for the flow.
|
44
|
+
"""
|
45
|
+
# Generate empty specification
|
46
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
44
47
|
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
48
|
+
# Specify JAX Array parameters
|
49
|
+
filter_spec = eqx.tree_at(
|
50
|
+
lambda flow: flow.params,
|
51
|
+
filter_spec,
|
52
|
+
replace=jtu.tree_map(eqx.is_array, self.params),
|
53
|
+
)
|
49
54
|
|
50
|
-
|
51
|
-
t_params[par] = map(t_params[par])
|
55
|
+
return filter_spec
|
52
56
|
|
53
|
-
|
57
|
+
@eqx.filter_jit
|
58
|
+
def constrain_pars(self: Self):
|
59
|
+
"""
|
60
|
+
Constrain `params` to the appropriate domain.
|
54
61
|
|
55
|
-
|
62
|
+
# Returns
|
63
|
+
A dictionary of transformed JAX Arrays representing the constrained parameters.
|
64
|
+
"""
|
65
|
+
t_params = self.params
|
56
66
|
|
57
|
-
|
58
|
-
|
59
|
-
def transform_pars(self: Self) -> Dict[str, Array]:
|
60
|
-
"""
|
61
|
-
Apply a custom transformation to `params` if needed.
|
67
|
+
for par, map in self.constraints.items():
|
68
|
+
t_params[par] = map(t_params[par])
|
62
69
|
|
63
|
-
|
64
|
-
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
65
|
-
"""
|
66
|
-
return self.constrain_pars()
|
70
|
+
return t_params
|
67
71
|
|
68
|
-
|
72
|
+
@eqx.filter_jit
|
73
|
+
def transform_pars(self: Self) -> Dict[str, Array]:
|
74
|
+
"""
|
75
|
+
Apply a custom transformation to `params` if needed.
|
76
|
+
|
77
|
+
# Returns
|
78
|
+
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
79
|
+
"""
|
80
|
+
return self.constrain_pars()
|
bayinx/core/model.py
CHANGED
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
|
2
2
|
from typing import Any, Callable, Dict
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
|
+
import jax.tree_util as jtu
|
5
6
|
from jaxtyping import Array, Scalar
|
6
7
|
|
7
8
|
from bayinx.core.utils import __MyMeta
|
@@ -23,6 +24,23 @@ class Model(eqx.Module, metaclass=__MyMeta):
|
|
23
24
|
def eval(self, data: Any) -> Scalar:
|
24
25
|
pass
|
25
26
|
|
27
|
+
# Default filter specification
|
28
|
+
def filter_spec(self):
|
29
|
+
"""
|
30
|
+
Generates a filter specification to subset relevant parameters for the model.
|
31
|
+
"""
|
32
|
+
# Generate empty specification
|
33
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
34
|
+
|
35
|
+
# Specify JAX Array parameters
|
36
|
+
filter_spec = eqx.tree_at(
|
37
|
+
lambda model: model.params,
|
38
|
+
filter_spec,
|
39
|
+
replace=jtu.tree_map(eqx.is_array, self.params),
|
40
|
+
)
|
41
|
+
|
42
|
+
return filter_spec
|
43
|
+
|
26
44
|
def __init_subclass__(cls):
|
27
45
|
# Add constrain method
|
28
46
|
def constrain_pars(self: Model) -> Dict[str, Array]:
|
@@ -43,6 +61,7 @@ class Model(eqx.Module, metaclass=__MyMeta):
|
|
43
61
|
|
44
62
|
# Add transform_pars method if not present
|
45
63
|
if not callable(getattr(cls, "transform_pars", None)):
|
64
|
+
|
46
65
|
def transform_pars(self: Model) -> Dict[str, Array]:
|
47
66
|
"""
|
48
67
|
Apply a custom transformation to `params` if needed.
|
bayinx/core/variational.py
CHANGED
@@ -104,6 +104,9 @@ class Variational(eqx.Module):
|
|
104
104
|
- `var_draws`: Number of variational draws to draw each iteration.
|
105
105
|
- `key`: A PRNG key.
|
106
106
|
"""
|
107
|
+
# Partition variational
|
108
|
+
dyn, static = eqx.partition(self, self.filter_spec())
|
109
|
+
|
107
110
|
# Construct scheduler
|
108
111
|
schedule: Schedule = opx.cosine_decay_schedule(
|
109
112
|
init_value=learning_rate, decay_steps=max_iters
|
@@ -113,20 +116,20 @@ class Variational(eqx.Module):
|
|
113
116
|
optim: GradientTransformation = opx.chain(
|
114
117
|
opx.scale(-1.0), opx.nadam(schedule)
|
115
118
|
)
|
116
|
-
opt_state: OptState = optim.init(
|
119
|
+
opt_state: OptState = optim.init(dyn)
|
117
120
|
|
118
121
|
# Optimization loop helper functions
|
119
122
|
@eqx.filter_jit
|
120
123
|
def condition(state: Tuple[Self, OptState, Scalar, Key]):
|
121
124
|
# Unpack iteration state
|
122
|
-
|
125
|
+
dyn, opt_state, i, key = state
|
123
126
|
|
124
127
|
return i < max_iters
|
125
128
|
|
126
129
|
@eqx.filter_jit
|
127
130
|
def body(state: Tuple[Self, OptState, Scalar, Key]):
|
128
131
|
# Unpack iteration state
|
129
|
-
|
132
|
+
dyn, opt_state, i, key = state
|
130
133
|
|
131
134
|
# Update iteration
|
132
135
|
i = i + 1
|
@@ -134,26 +137,30 @@ class Variational(eqx.Module):
|
|
134
137
|
# Update PRNG key
|
135
138
|
key, _ = jr.split(key)
|
136
139
|
|
140
|
+
# Combine variational
|
141
|
+
vari = eqx.combine(dyn, static)
|
142
|
+
|
137
143
|
# Compute gradient of the ELBO
|
138
|
-
updates: PyTree =
|
144
|
+
updates: PyTree = vari.elbo_grad(var_draws, key, data)
|
139
145
|
|
140
146
|
# Compute updates
|
141
147
|
updates, opt_state = optim.update(
|
142
|
-
updates, opt_state, eqx.filter(
|
148
|
+
updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
|
143
149
|
)
|
144
150
|
|
145
151
|
# Update variational distribution
|
146
|
-
|
152
|
+
dyn = eqx.apply_updates(dyn, updates)
|
147
153
|
|
148
|
-
return
|
154
|
+
return dyn, opt_state, i, key
|
149
155
|
|
150
156
|
# Run optimization loop
|
151
|
-
|
157
|
+
dyn = lax.while_loop(
|
152
158
|
cond_fun=condition,
|
153
159
|
body_fun=body,
|
154
|
-
init_val=(
|
160
|
+
init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
|
155
161
|
)[0]
|
156
162
|
|
157
|
-
|
163
|
+
# Return optimized variational
|
164
|
+
return eqx.combine(dyn, static)
|
158
165
|
|
159
166
|
cls.fit = eqx.filter_jit(fit)
|
bayinx/mhx/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
|
@@ -19,9 +19,7 @@ class Affine(Flow):
|
|
19
19
|
"""
|
20
20
|
|
21
21
|
params: Dict[str, Float[Array, "..."]]
|
22
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
23
|
-
eqx.field(static=True)
|
24
|
-
)
|
22
|
+
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
25
23
|
|
26
24
|
def __init__(self, dim: int):
|
27
25
|
"""
|
@@ -20,9 +20,7 @@ class Planar(Flow):
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
params: Dict[str, Float[Array, "..."]]
|
23
|
-
constraints: Dict[str, Callable[[
|
24
|
-
eqx.field(static=True)
|
25
|
-
)
|
23
|
+
constraints: Dict[str, Callable[[Array], Array]]
|
26
24
|
|
27
25
|
def __init__(self, dim: int, key=jr.PRNGKey(0)):
|
28
26
|
"""
|
@@ -21,9 +21,7 @@ class Radial(Flow):
|
|
21
21
|
"""
|
22
22
|
|
23
23
|
params: Dict[str, Float[Array, "..."]]
|
24
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
25
|
-
eqx.field(static=True)
|
26
|
-
)
|
24
|
+
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
27
25
|
|
28
26
|
def __init__(self, dim: int, key=jr.PRNGKey(0)):
|
29
27
|
"""
|
@@ -88,7 +86,8 @@ class Radial(Flow):
|
|
88
86
|
# Compute density adjustment
|
89
87
|
ladj = jnp.log(
|
90
88
|
jnp.abs(
|
91
|
-
(1.0 + alpha * beta / (alpha + r) ** 2.0)
|
89
|
+
(1.0 + alpha * beta / (alpha + r) ** 2.0)
|
90
|
+
* (1.0 + x) ** (center.size - 1.0)
|
92
91
|
)
|
93
92
|
)
|
94
93
|
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import Callable, Dict
|
2
|
+
|
3
|
+
from jaxtyping import Array, Float
|
4
|
+
|
5
|
+
from bayinx.core import Flow
|
6
|
+
|
7
|
+
|
8
|
+
# TODO
|
9
|
+
class Sylvester(Flow):
|
10
|
+
"""
|
11
|
+
A sylvester flow.
|
12
|
+
|
13
|
+
# Attributes
|
14
|
+
- `params`: A dictionary containing the JAX Arrays representing the flow parameters.
|
15
|
+
- `constraints`: A dictionary of constraining transformations.
|
16
|
+
"""
|
17
|
+
|
18
|
+
params: Dict[str, Float[Array, "..."]]
|
19
|
+
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
@@ -31,15 +31,15 @@ class MeanField(Variational):
|
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
32
|
"""
|
33
33
|
# Partition model
|
34
|
-
params, self._constraints = eqx.partition(model,
|
34
|
+
params, self._constraints = eqx.partition(model, model.filter_spec())
|
35
35
|
|
36
36
|
# Flatten params component
|
37
|
-
|
37
|
+
params, self._unflatten = ravel_pytree(params)
|
38
38
|
|
39
39
|
# Initialize variational parameters
|
40
40
|
self.var_params = {
|
41
|
-
"mean":
|
42
|
-
"log_std": jnp.zeros(
|
41
|
+
"mean": params,
|
42
|
+
"log_std": jnp.zeros(params.size, dtype=params.dtype),
|
43
43
|
}
|
44
44
|
|
45
45
|
@eqx.filter_jit
|
@@ -23,8 +23,8 @@ class NormalizingFlow(Variational):
|
|
23
23
|
|
24
24
|
flows: list[Flow]
|
25
25
|
base: Variational
|
26
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
27
|
-
_constraints: Model
|
26
|
+
_unflatten: Callable[[Float[Array, "..."]], Model]
|
27
|
+
_constraints: Model
|
28
28
|
|
29
29
|
def __init__(self, base: Variational, flows: list[Flow], model: Model):
|
30
30
|
"""
|
@@ -39,7 +39,7 @@ class NormalizingFlow(Variational):
|
|
39
39
|
params, self._constraints = eqx.partition(model, eqx.is_array)
|
40
40
|
|
41
41
|
# Flatten params component
|
42
|
-
|
42
|
+
_, self._unflatten = jfu.ravel_pytree(params)
|
43
43
|
|
44
44
|
self.base = base
|
45
45
|
self.flows = flows
|
@@ -73,7 +73,7 @@ class NormalizingFlow(Variational):
|
|
73
73
|
return variational_evals
|
74
74
|
|
75
75
|
@eqx.filter_jit
|
76
|
-
def
|
76
|
+
def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
|
77
77
|
"""
|
78
78
|
Evaluate the posterior and variational densities at the transformed
|
79
79
|
`draws` to avoid extra compute when requiring variational draws for
|
@@ -84,7 +84,7 @@ class NormalizingFlow(Variational):
|
|
84
84
|
- `data`: Any data required to evaluate the posterior density.
|
85
85
|
|
86
86
|
# Returns
|
87
|
-
The posterior and variational densities.
|
87
|
+
The posterior and variational densities as JAX Arrays.
|
88
88
|
"""
|
89
89
|
# Evaluate base density
|
90
90
|
variational_evals: Array = self.base.eval(draws)
|
@@ -102,30 +102,30 @@ class NormalizingFlow(Variational):
|
|
102
102
|
return posterior_evals, variational_evals
|
103
103
|
|
104
104
|
def filter_spec(self):
|
105
|
-
#
|
105
|
+
# Generate empty specification
|
106
106
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
107
|
+
|
108
|
+
# Specify variational parameters based on each flow's filter spec.
|
107
109
|
filter_spec = eqx.tree_at(
|
108
|
-
lambda
|
110
|
+
lambda vari: vari.flows,
|
109
111
|
filter_spec,
|
110
|
-
replace=
|
112
|
+
replace=[flow.filter_spec() for flow in self.flows],
|
111
113
|
)
|
112
114
|
|
113
115
|
return filter_spec
|
114
116
|
|
115
117
|
@eqx.filter_jit
|
116
118
|
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
117
|
-
# Partition
|
118
119
|
dyn, static = eqx.partition(self, self.filter_spec())
|
119
120
|
|
120
121
|
@eqx.filter_jit
|
121
122
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None):
|
122
|
-
# Combine
|
123
123
|
self = eqx.combine(dyn, static)
|
124
124
|
|
125
125
|
# Sample draws from variational distribution
|
126
126
|
draws: Array = self.base.sample(n, key)
|
127
127
|
|
128
|
-
posterior_evals, variational_evals = self.
|
128
|
+
posterior_evals, variational_evals = self.__eval(draws, data)
|
129
129
|
# Evaluate ELBO
|
130
130
|
return jnp.mean(posterior_evals - variational_evals)
|
131
131
|
|
@@ -133,19 +133,17 @@ class NormalizingFlow(Variational):
|
|
133
133
|
|
134
134
|
@eqx.filter_jit
|
135
135
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
136
|
-
# Partition
|
137
136
|
dyn, static = eqx.partition(self, self.filter_spec())
|
138
137
|
|
139
138
|
@eqx.filter_grad
|
140
139
|
@eqx.filter_jit
|
141
140
|
def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
|
142
|
-
# Combine
|
143
141
|
self = eqx.combine(dyn, static)
|
144
142
|
|
145
143
|
# Sample draws from variational distribution
|
146
144
|
draws: Array = self.base.sample(n, key)
|
147
145
|
|
148
|
-
posterior_evals, variational_evals = self.
|
146
|
+
posterior_evals, variational_evals = self.__eval(draws, data)
|
149
147
|
# Evaluate ELBO
|
150
148
|
return jnp.mean(posterior_evals - variational_evals)
|
151
149
|
|
@@ -13,15 +13,15 @@ from bayinx.dists import normal
|
|
13
13
|
|
14
14
|
class Standard(Variational):
|
15
15
|
"""
|
16
|
-
A standard normal
|
16
|
+
A standard normal approximation to a posterior distribution.
|
17
17
|
|
18
18
|
# Attributes
|
19
19
|
- `dim`: Dimension of the parameter space.
|
20
20
|
"""
|
21
21
|
|
22
22
|
dim: int = eqx.field(static=True)
|
23
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
-
_constraints: Model
|
23
|
+
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
+
_constraints: Model
|
25
25
|
|
26
26
|
def __init__(self, model: Model):
|
27
27
|
"""
|
@@ -31,13 +31,13 @@ class Standard(Variational):
|
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
32
|
"""
|
33
33
|
# Partition model
|
34
|
-
|
34
|
+
params, self._constraints = eqx.partition(model, model.filter_spec())
|
35
35
|
|
36
36
|
# Flatten params component
|
37
|
-
|
37
|
+
params, self._unflatten = ravel_pytree(params)
|
38
38
|
|
39
39
|
# Store dimension of parameter space
|
40
|
-
self.dim = jnp.size(
|
40
|
+
self.dim = jnp.size(params)
|
41
41
|
|
42
42
|
@eqx.filter_jit
|
43
43
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
@@ -1,14 +1,12 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: bayinx
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.9
|
4
4
|
Summary: Bayesian Inference with JAX
|
5
5
|
Requires-Python: >=3.12
|
6
6
|
Requires-Dist: equinox>=0.11.12
|
7
7
|
Requires-Dist: jax>=0.4.38
|
8
8
|
Requires-Dist: jaxtyping>=0.2.36
|
9
9
|
Requires-Dist: optax>=0.2.4
|
10
|
-
Requires-Dist: pytest-benchmark>=5.1.0
|
11
|
-
Requires-Dist: pytest>=8.3.5
|
12
10
|
Description-Content-Type: text/markdown
|
13
11
|
|
14
12
|
# <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
@@ -23,7 +21,7 @@ In the short-term, I'm going to focus on:
|
|
23
21
|
In the long-term, I'm going to focus on:
|
24
22
|
1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead(to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed(same as `Stan` does).
|
25
23
|
2) How to make working with the posterior as easy as possible.
|
26
|
-
- That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
|
24
|
+
- That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{D}, \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{D}, \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
|
27
25
|
|
28
26
|
Although this is somewhat separate from the goals of the project, if this does pan out how I'm invisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar(think `brms`).
|
29
27
|
|
@@ -0,0 +1,26 @@
|
|
1
|
+
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
+
bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
|
5
|
+
bayinx/core/model.py,sha256=-rT3NHjxqGB0lDBMi0Mr9XNOz1_TUnJWtd4ITj0rsus,2257
|
6
|
+
bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
|
7
|
+
bayinx/core/variational.py,sha256=yhraTVlNOJaU1NEYVrWpUXVzzWvY1Mq9ZOZv6V0_Vo0,5183
|
8
|
+
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
|
+
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
+
bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
|
14
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
15
|
+
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
16
|
+
bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
|
17
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=V0-R2Nc_3Vy5c2rkjDOwIutA0G-_kN6trxjwsT5FgsA,4774
|
18
|
+
bayinx/mhx/vi/standard.py,sha256=m5gtcHfrYzV28h-Red3Zn6SxEgJlndeIXiIG5gDPecU,1703
|
19
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=V_Ng5cecKlLlFSI9ncmaiyvoy_d2EAfeDhBFcy5aQhA,168
|
20
|
+
bayinx/mhx/vi/flows/affine.py,sha256=a205nNx6KRvOwGlnjI6YeDo7OTWPPIxffGZfAcTecNA,1707
|
21
|
+
bayinx/mhx/vi/flows/planar.py,sha256=0BGdMm-GpTCJnxq9cOrgLl8IsHgGIL0eSFagWJNVdqQ,1944
|
22
|
+
bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
|
23
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
24
|
+
bayinx-0.2.9.dist-info/METADATA,sha256=xp6L_DdXPC-TMHV4SL5LdIuhzX8GizUlx2muMgSFcy0,3057
|
25
|
+
bayinx-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
26
|
+
bayinx-0.2.9.dist-info/RECORD,,
|
bayinx/machinery/__init__.py
DELETED
File without changes
|
@@ -1,76 +0,0 @@
|
|
1
|
-
from functools import partial
|
2
|
-
from typing import Callable, Dict, Tuple
|
3
|
-
|
4
|
-
import equinox as eqx
|
5
|
-
import jax
|
6
|
-
import jax.numpy as jnp
|
7
|
-
import jax.random as jr
|
8
|
-
from jaxtyping import Array, Float, Scalar
|
9
|
-
|
10
|
-
from bayinx.core import Flow
|
11
|
-
|
12
|
-
|
13
|
-
class Sylvester(Flow):
|
14
|
-
"""
|
15
|
-
A sylvester flow.
|
16
|
-
|
17
|
-
# Attributes
|
18
|
-
- `params`: A dictionary containing the JAX Arrays representing the flow parameters.
|
19
|
-
- `constraints`: A dictionary of constraining transformations.
|
20
|
-
"""
|
21
|
-
|
22
|
-
params: Dict[str, Float[Array, "..."]]
|
23
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
|
24
|
-
eqx.field(static=True)
|
25
|
-
)
|
26
|
-
|
27
|
-
def __init__(self, dim: int, key=jr.PRNGKey(0)):
|
28
|
-
"""
|
29
|
-
Initializes a planar flow.
|
30
|
-
|
31
|
-
# Parameters
|
32
|
-
- `dim`: The dimension of the parameter space.
|
33
|
-
"""
|
34
|
-
self.params = {
|
35
|
-
"u": jr.normal(key, (dim,)),
|
36
|
-
"w": jr.normal(key, (dim,)),
|
37
|
-
"b": jr.normal(key, (1,)),
|
38
|
-
}
|
39
|
-
self.constraints = {}
|
40
|
-
|
41
|
-
@eqx.filter_jit
|
42
|
-
@partial(jax.vmap, in_axes=(None, 0))
|
43
|
-
def forward(self, draws: Array) -> Array:
|
44
|
-
params = self.constrain_pars()
|
45
|
-
|
46
|
-
# Extract parameters
|
47
|
-
w: Array = params["w"]
|
48
|
-
u: Array = params["u"]
|
49
|
-
b: Array = params["b"]
|
50
|
-
|
51
|
-
# Compute forward transformation
|
52
|
-
draws = draws + u * jnp.tanh(draws.dot(w) + b)
|
53
|
-
|
54
|
-
return draws
|
55
|
-
|
56
|
-
@eqx.filter_jit
|
57
|
-
@partial(jax.vmap, in_axes=(None, 0))
|
58
|
-
def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
|
59
|
-
params = self.constrain_pars()
|
60
|
-
|
61
|
-
# Extract parameters
|
62
|
-
w: Array = params["w"]
|
63
|
-
u: Array = params["u"]
|
64
|
-
b: Array = params["b"]
|
65
|
-
|
66
|
-
# Compute shared intermediates
|
67
|
-
x: Array = draws.dot(w) + b
|
68
|
-
|
69
|
-
# Compute forward transformation
|
70
|
-
draws = draws + u * jnp.tanh(x)
|
71
|
-
|
72
|
-
# Compute ladj
|
73
|
-
h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
|
74
|
-
ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
|
75
|
-
|
76
|
-
return ladj, draws
|
bayinx-0.2.6.dist-info/RECORD
DELETED
@@ -1,26 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
-
bayinx/core/flow.py,sha256=4vj1t2xNPGp1VPE4xUshY-rHAw__KvSwjGDtKkW2taE,2252
|
5
|
-
bayinx/core/model.py,sha256=AI4eHrXAds3K7eWgZ9g5E6Kh76HP6WTn6s6q_0tnhck,1719
|
6
|
-
bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
|
7
|
-
bayinx/core/variational.py,sha256=T42uUNkF2tP1HJPyeIv7ISdika_G28wR_OOFXzx_hgo,4978
|
8
|
-
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
|
-
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
|
14
|
-
bayinx/machinery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
-
bayinx/machinery/variational/__init__.py,sha256=5GdhqBOHKsXg2tZGAMNlxyrLPD0-s64wAEy8998cHZ4,247
|
16
|
-
bayinx/machinery/variational/meanfield.py,sha256=stI96DNhHNIROnr8rLNEN9SN_0lXqkwit0KNto34Q6A,3889
|
17
|
-
bayinx/machinery/variational/normalizing_flow.py,sha256=qypgPq9vIqSIJNOHCDaN-hvwFfttNQ5_yXqvmi5hslI,4796
|
18
|
-
bayinx/machinery/variational/standard.py,sha256=IQdNd5QIE8u3zcOw7K4EW69lIQ0ZUGGDvwZVyvrYHxA,1739
|
19
|
-
bayinx/machinery/variational/flows/__init__.py,sha256=VGh-ffuUfMso_0JxwGCJQ2yVnFJdOrkFsSnorojQldY,213
|
20
|
-
bayinx/machinery/variational/flows/affine.py,sha256=TPyUUPRoSkyDMwGO5wtq-Ei8DAUvlb_N6JCk7uPlbJQ,1748
|
21
|
-
bayinx/machinery/variational/flows/planar.py,sha256=rJ1XpqoWzig_5Udq6oCh5JV4ptlTRLRS7tb9DCX22lE,2013
|
22
|
-
bayinx/machinery/variational/flows/radial.py,sha256=NF1tCd_PH6m8eqjJkom2c30sRUQ04Vf8zeRx_RQCDcg,2526
|
23
|
-
bayinx/machinery/variational/flows/sylvester.py,sha256=DeZl4Fkz9XpCGsfDcjS0eWlrMR0xMO0MPfJvoDhixSA,2019
|
24
|
-
bayinx-0.2.6.dist-info/METADATA,sha256=LGWQJPIHXjtLo7f6X_bsNMwnu5DfKpJ9kRTAJ91fxKU,3099
|
25
|
-
bayinx-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
26
|
-
bayinx-0.2.6.dist-info/RECORD,,
|
File without changes
|