bayinx 0.2.6__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/core/flow.py +35 -23
- bayinx/core/model.py +19 -0
- bayinx/core/variational.py +19 -11
- bayinx/mhx/__init__.py +1 -0
- bayinx/mhx/vi/__init__.py +3 -0
- bayinx/mhx/vi/flows/__init__.py +3 -0
- bayinx/{machinery/variational → mhx/vi}/flows/affine.py +1 -3
- bayinx/{machinery/variational → mhx/vi}/flows/planar.py +1 -3
- bayinx/{machinery/variational → mhx/vi}/flows/radial.py +3 -4
- bayinx/mhx/vi/flows/sylvester.py +19 -0
- bayinx/{machinery/variational → mhx/vi}/meanfield.py +4 -4
- bayinx/{machinery/variational → mhx/vi}/normalizing_flow.py +12 -14
- bayinx/{machinery/variational → mhx/vi}/standard.py +6 -6
- {bayinx-0.2.6.dist-info → bayinx-0.2.10.dist-info}/METADATA +2 -4
- bayinx-0.2.10.dist-info/RECORD +26 -0
- bayinx/machinery/__init__.py +0 -0
- bayinx/machinery/variational/__init__.py +0 -5
- bayinx/machinery/variational/flows/__init__.py +0 -3
- bayinx/machinery/variational/flows/sylvester.py +0 -76
- bayinx-0.2.6.dist-info/RECORD +0 -26
- {bayinx-0.2.6.dist-info → bayinx-0.2.10.dist-info}/WHEEL +0 -0
bayinx/core/flow.py
CHANGED
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
|
2
2
|
from typing import Callable, Dict, Self, Tuple
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
|
+
import jax.tree_util as jtu
|
5
6
|
from jaxtyping import Array, Float
|
6
7
|
|
7
8
|
from bayinx.core.utils import __MyMeta
|
@@ -36,33 +37,44 @@ class Flow(eqx.Module, metaclass=__MyMeta):
|
|
36
37
|
"""
|
37
38
|
pass
|
38
39
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
40
|
+
# Default filter specification
|
41
|
+
def filter_spec(self):
|
42
|
+
"""
|
43
|
+
Generates a filter specification to subset relevant parameters for the flow.
|
44
|
+
"""
|
45
|
+
# Generate empty specification
|
46
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
44
47
|
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
48
|
+
# Specify JAX Array parameters
|
49
|
+
filter_spec = eqx.tree_at(
|
50
|
+
lambda flow: flow.params,
|
51
|
+
filter_spec,
|
52
|
+
replace=jtu.tree_map(eqx.is_array, self.params),
|
53
|
+
)
|
49
54
|
|
50
|
-
|
51
|
-
t_params[par] = map(t_params[par])
|
55
|
+
return filter_spec
|
52
56
|
|
53
|
-
|
57
|
+
@eqx.filter_jit
|
58
|
+
def constrain_pars(self: Self):
|
59
|
+
"""
|
60
|
+
Constrain `params` to the appropriate domain.
|
54
61
|
|
55
|
-
|
62
|
+
# Returns
|
63
|
+
A dictionary of transformed JAX Arrays representing the constrained parameters.
|
64
|
+
"""
|
65
|
+
t_params = self.params
|
56
66
|
|
57
|
-
|
58
|
-
|
59
|
-
def transform_pars(self: Self) -> Dict[str, Array]:
|
60
|
-
"""
|
61
|
-
Apply a custom transformation to `params` if needed.
|
67
|
+
for par, map in self.constraints.items():
|
68
|
+
t_params[par] = map(t_params[par])
|
62
69
|
|
63
|
-
|
64
|
-
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
65
|
-
"""
|
66
|
-
return self.constrain_pars()
|
70
|
+
return t_params
|
67
71
|
|
68
|
-
|
72
|
+
@eqx.filter_jit
|
73
|
+
def transform_pars(self: Self) -> Dict[str, Array]:
|
74
|
+
"""
|
75
|
+
Apply a custom transformation to `params` if needed.
|
76
|
+
|
77
|
+
# Returns
|
78
|
+
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
79
|
+
"""
|
80
|
+
return self.constrain_pars()
|
bayinx/core/model.py
CHANGED
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
|
2
2
|
from typing import Any, Callable, Dict
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
|
+
import jax.tree_util as jtu
|
5
6
|
from jaxtyping import Array, Scalar
|
6
7
|
|
7
8
|
from bayinx.core.utils import __MyMeta
|
@@ -23,6 +24,23 @@ class Model(eqx.Module, metaclass=__MyMeta):
|
|
23
24
|
def eval(self, data: Any) -> Scalar:
|
24
25
|
pass
|
25
26
|
|
27
|
+
# Default filter specification
|
28
|
+
def filter_spec(self):
|
29
|
+
"""
|
30
|
+
Generates a filter specification to subset relevant parameters for the model.
|
31
|
+
"""
|
32
|
+
# Generate empty specification
|
33
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
34
|
+
|
35
|
+
# Specify JAX Array parameters
|
36
|
+
filter_spec = eqx.tree_at(
|
37
|
+
lambda model: model.params,
|
38
|
+
filter_spec,
|
39
|
+
replace=jtu.tree_map(eqx.is_array, self.params),
|
40
|
+
)
|
41
|
+
|
42
|
+
return filter_spec
|
43
|
+
|
26
44
|
def __init_subclass__(cls):
|
27
45
|
# Add constrain method
|
28
46
|
def constrain_pars(self: Model) -> Dict[str, Array]:
|
@@ -43,6 +61,7 @@ class Model(eqx.Module, metaclass=__MyMeta):
|
|
43
61
|
|
44
62
|
# Add transform_pars method if not present
|
45
63
|
if not callable(getattr(cls, "transform_pars", None)):
|
64
|
+
|
46
65
|
def transform_pars(self: Model) -> Dict[str, Array]:
|
47
66
|
"""
|
48
67
|
Apply a custom transformation to `params` if needed.
|
bayinx/core/variational.py
CHANGED
@@ -89,6 +89,7 @@ class Variational(eqx.Module):
|
|
89
89
|
max_iters: int,
|
90
90
|
data: Any = None,
|
91
91
|
learning_rate: float = 1,
|
92
|
+
weight_decay: float = 1e-4,
|
92
93
|
tolerance: float = 1e-4,
|
93
94
|
var_draws: int = 1,
|
94
95
|
key: Key = jr.PRNGKey(0),
|
@@ -104,6 +105,9 @@ class Variational(eqx.Module):
|
|
104
105
|
- `var_draws`: Number of variational draws to draw each iteration.
|
105
106
|
- `key`: A PRNG key.
|
106
107
|
"""
|
108
|
+
# Partition variational
|
109
|
+
dyn, static = eqx.partition(self, self.filter_spec())
|
110
|
+
|
107
111
|
# Construct scheduler
|
108
112
|
schedule: Schedule = opx.cosine_decay_schedule(
|
109
113
|
init_value=learning_rate, decay_steps=max_iters
|
@@ -111,22 +115,22 @@ class Variational(eqx.Module):
|
|
111
115
|
|
112
116
|
# Initialize optimizer
|
113
117
|
optim: GradientTransformation = opx.chain(
|
114
|
-
opx.scale(-1.0), opx.
|
118
|
+
opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
|
115
119
|
)
|
116
|
-
opt_state: OptState = optim.init(
|
120
|
+
opt_state: OptState = optim.init(dyn)
|
117
121
|
|
118
122
|
# Optimization loop helper functions
|
119
123
|
@eqx.filter_jit
|
120
124
|
def condition(state: Tuple[Self, OptState, Scalar, Key]):
|
121
125
|
# Unpack iteration state
|
122
|
-
|
126
|
+
dyn, opt_state, i, key = state
|
123
127
|
|
124
128
|
return i < max_iters
|
125
129
|
|
126
130
|
@eqx.filter_jit
|
127
131
|
def body(state: Tuple[Self, OptState, Scalar, Key]):
|
128
132
|
# Unpack iteration state
|
129
|
-
|
133
|
+
dyn, opt_state, i, key = state
|
130
134
|
|
131
135
|
# Update iteration
|
132
136
|
i = i + 1
|
@@ -134,26 +138,30 @@ class Variational(eqx.Module):
|
|
134
138
|
# Update PRNG key
|
135
139
|
key, _ = jr.split(key)
|
136
140
|
|
141
|
+
# Combine variational
|
142
|
+
vari = eqx.combine(dyn, static)
|
143
|
+
|
137
144
|
# Compute gradient of the ELBO
|
138
|
-
updates: PyTree =
|
145
|
+
updates: PyTree = vari.elbo_grad(var_draws, key, data)
|
139
146
|
|
140
147
|
# Compute updates
|
141
148
|
updates, opt_state = optim.update(
|
142
|
-
updates, opt_state, eqx.filter(
|
149
|
+
updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
|
143
150
|
)
|
144
151
|
|
145
152
|
# Update variational distribution
|
146
|
-
|
153
|
+
dyn = eqx.apply_updates(dyn, updates)
|
147
154
|
|
148
|
-
return
|
155
|
+
return dyn, opt_state, i, key
|
149
156
|
|
150
157
|
# Run optimization loop
|
151
|
-
|
158
|
+
dyn = lax.while_loop(
|
152
159
|
cond_fun=condition,
|
153
160
|
body_fun=body,
|
154
|
-
init_val=(
|
161
|
+
init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
|
155
162
|
)[0]
|
156
163
|
|
157
|
-
|
164
|
+
# Return optimized variational
|
165
|
+
return eqx.combine(dyn, static)
|
158
166
|
|
159
167
|
cls.fit = eqx.filter_jit(fit)
|
bayinx/mhx/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
|
@@ -19,9 +19,7 @@ class Affine(Flow):
|
|
19
19
|
"""
|
20
20
|
|
21
21
|
params: Dict[str, Float[Array, "..."]]
|
22
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
23
|
-
eqx.field(static=True)
|
24
|
-
)
|
22
|
+
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
25
23
|
|
26
24
|
def __init__(self, dim: int):
|
27
25
|
"""
|
@@ -20,9 +20,7 @@ class Planar(Flow):
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
params: Dict[str, Float[Array, "..."]]
|
23
|
-
constraints: Dict[str, Callable[[
|
24
|
-
eqx.field(static=True)
|
25
|
-
)
|
23
|
+
constraints: Dict[str, Callable[[Array], Array]]
|
26
24
|
|
27
25
|
def __init__(self, dim: int, key=jr.PRNGKey(0)):
|
28
26
|
"""
|
@@ -21,9 +21,7 @@ class Radial(Flow):
|
|
21
21
|
"""
|
22
22
|
|
23
23
|
params: Dict[str, Float[Array, "..."]]
|
24
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
25
|
-
eqx.field(static=True)
|
26
|
-
)
|
24
|
+
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
27
25
|
|
28
26
|
def __init__(self, dim: int, key=jr.PRNGKey(0)):
|
29
27
|
"""
|
@@ -88,7 +86,8 @@ class Radial(Flow):
|
|
88
86
|
# Compute density adjustment
|
89
87
|
ladj = jnp.log(
|
90
88
|
jnp.abs(
|
91
|
-
(1.0 + alpha * beta / (alpha + r) ** 2.0)
|
89
|
+
(1.0 + alpha * beta / (alpha + r) ** 2.0)
|
90
|
+
* (1.0 + x) ** (center.size - 1.0)
|
92
91
|
)
|
93
92
|
)
|
94
93
|
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import Callable, Dict
|
2
|
+
|
3
|
+
from jaxtyping import Array, Float
|
4
|
+
|
5
|
+
from bayinx.core import Flow
|
6
|
+
|
7
|
+
|
8
|
+
# TODO
|
9
|
+
class Sylvester(Flow):
|
10
|
+
"""
|
11
|
+
A sylvester flow.
|
12
|
+
|
13
|
+
# Attributes
|
14
|
+
- `params`: A dictionary containing the JAX Arrays representing the flow parameters.
|
15
|
+
- `constraints`: A dictionary of constraining transformations.
|
16
|
+
"""
|
17
|
+
|
18
|
+
params: Dict[str, Float[Array, "..."]]
|
19
|
+
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
|
@@ -31,15 +31,15 @@ class MeanField(Variational):
|
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
32
|
"""
|
33
33
|
# Partition model
|
34
|
-
params, self._constraints = eqx.partition(model,
|
34
|
+
params, self._constraints = eqx.partition(model, model.filter_spec())
|
35
35
|
|
36
36
|
# Flatten params component
|
37
|
-
|
37
|
+
params, self._unflatten = ravel_pytree(params)
|
38
38
|
|
39
39
|
# Initialize variational parameters
|
40
40
|
self.var_params = {
|
41
|
-
"mean":
|
42
|
-
"log_std": jnp.zeros(
|
41
|
+
"mean": params,
|
42
|
+
"log_std": jnp.zeros(params.size, dtype=params.dtype),
|
43
43
|
}
|
44
44
|
|
45
45
|
@eqx.filter_jit
|
@@ -23,8 +23,8 @@ class NormalizingFlow(Variational):
|
|
23
23
|
|
24
24
|
flows: list[Flow]
|
25
25
|
base: Variational
|
26
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
27
|
-
_constraints: Model
|
26
|
+
_unflatten: Callable[[Float[Array, "..."]], Model]
|
27
|
+
_constraints: Model
|
28
28
|
|
29
29
|
def __init__(self, base: Variational, flows: list[Flow], model: Model):
|
30
30
|
"""
|
@@ -39,7 +39,7 @@ class NormalizingFlow(Variational):
|
|
39
39
|
params, self._constraints = eqx.partition(model, eqx.is_array)
|
40
40
|
|
41
41
|
# Flatten params component
|
42
|
-
|
42
|
+
_, self._unflatten = jfu.ravel_pytree(params)
|
43
43
|
|
44
44
|
self.base = base
|
45
45
|
self.flows = flows
|
@@ -73,7 +73,7 @@ class NormalizingFlow(Variational):
|
|
73
73
|
return variational_evals
|
74
74
|
|
75
75
|
@eqx.filter_jit
|
76
|
-
def
|
76
|
+
def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
|
77
77
|
"""
|
78
78
|
Evaluate the posterior and variational densities at the transformed
|
79
79
|
`draws` to avoid extra compute when requiring variational draws for
|
@@ -84,7 +84,7 @@ class NormalizingFlow(Variational):
|
|
84
84
|
- `data`: Any data required to evaluate the posterior density.
|
85
85
|
|
86
86
|
# Returns
|
87
|
-
The posterior and variational densities.
|
87
|
+
The posterior and variational densities as JAX Arrays.
|
88
88
|
"""
|
89
89
|
# Evaluate base density
|
90
90
|
variational_evals: Array = self.base.eval(draws)
|
@@ -102,30 +102,30 @@ class NormalizingFlow(Variational):
|
|
102
102
|
return posterior_evals, variational_evals
|
103
103
|
|
104
104
|
def filter_spec(self):
|
105
|
-
#
|
105
|
+
# Generate empty specification
|
106
106
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
107
|
+
|
108
|
+
# Specify variational parameters based on each flow's filter spec.
|
107
109
|
filter_spec = eqx.tree_at(
|
108
|
-
lambda
|
110
|
+
lambda vari: vari.flows,
|
109
111
|
filter_spec,
|
110
|
-
replace=
|
112
|
+
replace=[flow.filter_spec() for flow in self.flows],
|
111
113
|
)
|
112
114
|
|
113
115
|
return filter_spec
|
114
116
|
|
115
117
|
@eqx.filter_jit
|
116
118
|
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
117
|
-
# Partition
|
118
119
|
dyn, static = eqx.partition(self, self.filter_spec())
|
119
120
|
|
120
121
|
@eqx.filter_jit
|
121
122
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None):
|
122
|
-
# Combine
|
123
123
|
self = eqx.combine(dyn, static)
|
124
124
|
|
125
125
|
# Sample draws from variational distribution
|
126
126
|
draws: Array = self.base.sample(n, key)
|
127
127
|
|
128
|
-
posterior_evals, variational_evals = self.
|
128
|
+
posterior_evals, variational_evals = self.__eval(draws, data)
|
129
129
|
# Evaluate ELBO
|
130
130
|
return jnp.mean(posterior_evals - variational_evals)
|
131
131
|
|
@@ -133,19 +133,17 @@ class NormalizingFlow(Variational):
|
|
133
133
|
|
134
134
|
@eqx.filter_jit
|
135
135
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
136
|
-
# Partition
|
137
136
|
dyn, static = eqx.partition(self, self.filter_spec())
|
138
137
|
|
139
138
|
@eqx.filter_grad
|
140
139
|
@eqx.filter_jit
|
141
140
|
def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
|
142
|
-
# Combine
|
143
141
|
self = eqx.combine(dyn, static)
|
144
142
|
|
145
143
|
# Sample draws from variational distribution
|
146
144
|
draws: Array = self.base.sample(n, key)
|
147
145
|
|
148
|
-
posterior_evals, variational_evals = self.
|
146
|
+
posterior_evals, variational_evals = self.__eval(draws, data)
|
149
147
|
# Evaluate ELBO
|
150
148
|
return jnp.mean(posterior_evals - variational_evals)
|
151
149
|
|
@@ -13,15 +13,15 @@ from bayinx.dists import normal
|
|
13
13
|
|
14
14
|
class Standard(Variational):
|
15
15
|
"""
|
16
|
-
A standard normal
|
16
|
+
A standard normal approximation to a posterior distribution.
|
17
17
|
|
18
18
|
# Attributes
|
19
19
|
- `dim`: Dimension of the parameter space.
|
20
20
|
"""
|
21
21
|
|
22
22
|
dim: int = eqx.field(static=True)
|
23
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
-
_constraints: Model
|
23
|
+
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
+
_constraints: Model
|
25
25
|
|
26
26
|
def __init__(self, model: Model):
|
27
27
|
"""
|
@@ -31,13 +31,13 @@ class Standard(Variational):
|
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
32
|
"""
|
33
33
|
# Partition model
|
34
|
-
|
34
|
+
params, self._constraints = eqx.partition(model, model.filter_spec())
|
35
35
|
|
36
36
|
# Flatten params component
|
37
|
-
|
37
|
+
params, self._unflatten = ravel_pytree(params)
|
38
38
|
|
39
39
|
# Store dimension of parameter space
|
40
|
-
self.dim = jnp.size(
|
40
|
+
self.dim = jnp.size(params)
|
41
41
|
|
42
42
|
@eqx.filter_jit
|
43
43
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
@@ -1,14 +1,12 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: bayinx
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.10
|
4
4
|
Summary: Bayesian Inference with JAX
|
5
5
|
Requires-Python: >=3.12
|
6
6
|
Requires-Dist: equinox>=0.11.12
|
7
7
|
Requires-Dist: jax>=0.4.38
|
8
8
|
Requires-Dist: jaxtyping>=0.2.36
|
9
9
|
Requires-Dist: optax>=0.2.4
|
10
|
-
Requires-Dist: pytest-benchmark>=5.1.0
|
11
|
-
Requires-Dist: pytest>=8.3.5
|
12
10
|
Description-Content-Type: text/markdown
|
13
11
|
|
14
12
|
# <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
|
@@ -23,7 +21,7 @@ In the short-term, I'm going to focus on:
|
|
23
21
|
In the long-term, I'm going to focus on:
|
24
22
|
1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead (to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed (as `Stan` does).
|
25
23
|
2) How to make working with the posterior as easy as possible.
|
26
|
-
- That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
|
24
|
+
- That's a vague goal, but in practice it means making it easy to evaluate statements like $P(\theta \in [-1, 1] | \mathcal{D}, \mathcal{M})$, set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{D}, \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
|
27
25
|
|
28
26
|
Although this is somewhat separate from the goals of the project, if this pans out the way I'm envisioning, I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar (think `brms`).
|
29
27
|
|
@@ -0,0 +1,26 @@
|
|
1
|
+
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
+
bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
|
5
|
+
bayinx/core/model.py,sha256=-rT3NHjxqGB0lDBMi0Mr9XNOz1_TUnJWtd4ITj0rsus,2257
|
6
|
+
bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
|
7
|
+
bayinx/core/variational.py,sha256=3CsDyQkq1XgV2ZBLzGrm5XgUFoJBnT6glHDgxHNcbTc,5250
|
8
|
+
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
|
+
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
+
bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
|
14
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
15
|
+
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
16
|
+
bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
|
17
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=V0-R2Nc_3Vy5c2rkjDOwIutA0G-_kN6trxjwsT5FgsA,4774
|
18
|
+
bayinx/mhx/vi/standard.py,sha256=m5gtcHfrYzV28h-Red3Zn6SxEgJlndeIXiIG5gDPecU,1703
|
19
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=V_Ng5cecKlLlFSI9ncmaiyvoy_d2EAfeDhBFcy5aQhA,168
|
20
|
+
bayinx/mhx/vi/flows/affine.py,sha256=a205nNx6KRvOwGlnjI6YeDo7OTWPPIxffGZfAcTecNA,1707
|
21
|
+
bayinx/mhx/vi/flows/planar.py,sha256=0BGdMm-GpTCJnxq9cOrgLl8IsHgGIL0eSFagWJNVdqQ,1944
|
22
|
+
bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
|
23
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
24
|
+
bayinx-0.2.10.dist-info/METADATA,sha256=7Ej3pWMyQr0xLMmWb1WPhRDyxIQHiJ2sNfbTHkCCJ-E,3058
|
25
|
+
bayinx-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
26
|
+
bayinx-0.2.10.dist-info/RECORD,,
|
bayinx/machinery/__init__.py
DELETED
File without changes
|
@@ -1,76 +0,0 @@
|
|
1
|
-
from functools import partial
|
2
|
-
from typing import Callable, Dict, Tuple
|
3
|
-
|
4
|
-
import equinox as eqx
|
5
|
-
import jax
|
6
|
-
import jax.numpy as jnp
|
7
|
-
import jax.random as jr
|
8
|
-
from jaxtyping import Array, Float, Scalar
|
9
|
-
|
10
|
-
from bayinx.core import Flow
|
11
|
-
|
12
|
-
|
13
|
-
class Sylvester(Flow):
|
14
|
-
"""
|
15
|
-
A sylvester flow.
|
16
|
-
|
17
|
-
# Attributes
|
18
|
-
- `params`: A dictionary containing the JAX Arrays representing the flow parameters.
|
19
|
-
- `constraints`: A dictionary of constraining transformations.
|
20
|
-
"""
|
21
|
-
|
22
|
-
params: Dict[str, Float[Array, "..."]]
|
23
|
-
constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
|
24
|
-
eqx.field(static=True)
|
25
|
-
)
|
26
|
-
|
27
|
-
def __init__(self, dim: int, key=jr.PRNGKey(0)):
|
28
|
-
"""
|
29
|
-
Initializes a planar flow.
|
30
|
-
|
31
|
-
# Parameters
|
32
|
-
- `dim`: The dimension of the parameter space.
|
33
|
-
"""
|
34
|
-
self.params = {
|
35
|
-
"u": jr.normal(key, (dim,)),
|
36
|
-
"w": jr.normal(key, (dim,)),
|
37
|
-
"b": jr.normal(key, (1,)),
|
38
|
-
}
|
39
|
-
self.constraints = {}
|
40
|
-
|
41
|
-
@eqx.filter_jit
|
42
|
-
@partial(jax.vmap, in_axes=(None, 0))
|
43
|
-
def forward(self, draws: Array) -> Array:
|
44
|
-
params = self.constrain_pars()
|
45
|
-
|
46
|
-
# Extract parameters
|
47
|
-
w: Array = params["w"]
|
48
|
-
u: Array = params["u"]
|
49
|
-
b: Array = params["b"]
|
50
|
-
|
51
|
-
# Compute forward transformation
|
52
|
-
draws = draws + u * jnp.tanh(draws.dot(w) + b)
|
53
|
-
|
54
|
-
return draws
|
55
|
-
|
56
|
-
@eqx.filter_jit
|
57
|
-
@partial(jax.vmap, in_axes=(None, 0))
|
58
|
-
def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
|
59
|
-
params = self.constrain_pars()
|
60
|
-
|
61
|
-
# Extract parameters
|
62
|
-
w: Array = params["w"]
|
63
|
-
u: Array = params["u"]
|
64
|
-
b: Array = params["b"]
|
65
|
-
|
66
|
-
# Compute shared intermediates
|
67
|
-
x: Array = draws.dot(w) + b
|
68
|
-
|
69
|
-
# Compute forward transformation
|
70
|
-
draws = draws + u * jnp.tanh(x)
|
71
|
-
|
72
|
-
# Compute ladj
|
73
|
-
h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
|
74
|
-
ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
|
75
|
-
|
76
|
-
return ladj, draws
|
bayinx-0.2.6.dist-info/RECORD
DELETED
@@ -1,26 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
4
|
-
bayinx/core/flow.py,sha256=4vj1t2xNPGp1VPE4xUshY-rHAw__KvSwjGDtKkW2taE,2252
|
5
|
-
bayinx/core/model.py,sha256=AI4eHrXAds3K7eWgZ9g5E6Kh76HP6WTn6s6q_0tnhck,1719
|
6
|
-
bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
|
7
|
-
bayinx/core/variational.py,sha256=T42uUNkF2tP1HJPyeIv7ISdika_G28wR_OOFXzx_hgo,4978
|
8
|
-
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
10
|
-
bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
|
14
|
-
bayinx/machinery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
-
bayinx/machinery/variational/__init__.py,sha256=5GdhqBOHKsXg2tZGAMNlxyrLPD0-s64wAEy8998cHZ4,247
|
16
|
-
bayinx/machinery/variational/meanfield.py,sha256=stI96DNhHNIROnr8rLNEN9SN_0lXqkwit0KNto34Q6A,3889
|
17
|
-
bayinx/machinery/variational/normalizing_flow.py,sha256=qypgPq9vIqSIJNOHCDaN-hvwFfttNQ5_yXqvmi5hslI,4796
|
18
|
-
bayinx/machinery/variational/standard.py,sha256=IQdNd5QIE8u3zcOw7K4EW69lIQ0ZUGGDvwZVyvrYHxA,1739
|
19
|
-
bayinx/machinery/variational/flows/__init__.py,sha256=VGh-ffuUfMso_0JxwGCJQ2yVnFJdOrkFsSnorojQldY,213
|
20
|
-
bayinx/machinery/variational/flows/affine.py,sha256=TPyUUPRoSkyDMwGO5wtq-Ei8DAUvlb_N6JCk7uPlbJQ,1748
|
21
|
-
bayinx/machinery/variational/flows/planar.py,sha256=rJ1XpqoWzig_5Udq6oCh5JV4ptlTRLRS7tb9DCX22lE,2013
|
22
|
-
bayinx/machinery/variational/flows/radial.py,sha256=NF1tCd_PH6m8eqjJkom2c30sRUQ04Vf8zeRx_RQCDcg,2526
|
23
|
-
bayinx/machinery/variational/flows/sylvester.py,sha256=DeZl4Fkz9XpCGsfDcjS0eWlrMR0xMO0MPfJvoDhixSA,2019
|
24
|
-
bayinx-0.2.6.dist-info/METADATA,sha256=LGWQJPIHXjtLo7f6X_bsNMwnu5DfKpJ9kRTAJ91fxKU,3099
|
25
|
-
bayinx-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
26
|
-
bayinx-0.2.6.dist-info/RECORD,,
|
File without changes
|