bayinx 0.2.33__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +2 -1
- bayinx/constraints/lower.py +25 -11
- bayinx/core/__init__.py +1 -0
- bayinx/core/constraint.py +8 -6
- bayinx/core/flow.py +6 -4
- bayinx/core/model.py +12 -9
- bayinx/core/parameter.py +41 -0
- bayinx/core/variational.py +2 -2
- bayinx/mhx/vi/flows/fullaffine.py +2 -2
- bayinx/mhx/vi/flows/planar.py +2 -2
- bayinx/mhx/vi/flows/radial.py +2 -2
- bayinx/mhx/vi/meanfield.py +19 -17
- bayinx/mhx/vi/normalizing_flow.py +6 -4
- bayinx/mhx/vi/standard.py +3 -3
- {bayinx-0.2.33.dist-info → bayinx-0.3.0.dist-info}/METADATA +1 -1
- bayinx-0.3.0.dist-info/RECORD +30 -0
- bayinx-0.2.33.dist-info/RECORD +0 -29
- {bayinx-0.2.33.dist-info → bayinx-0.3.0.dist-info}/WHEEL +0 -0
bayinx/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1
|
-
from bayinx.core
|
1
|
+
from bayinx.core import Model as Model
|
2
|
+
from bayinx.core import Parameter as Parameter
|
bayinx/constraints/lower.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1
1
|
from typing import Tuple
|
2
2
|
|
3
|
+
import equinox as eqx
|
3
4
|
import jax.numpy as jnp
|
4
|
-
|
5
|
+
import jax.tree as jt
|
6
|
+
from jaxtyping import PyTree, Scalar, ScalarLike
|
5
7
|
|
6
8
|
from bayinx.core.constraint import Constraint
|
9
|
+
from bayinx.core.parameter import Parameter
|
7
10
|
|
8
11
|
|
9
12
|
class Lower(Constraint):
|
@@ -11,27 +14,38 @@ class Lower(Constraint):
|
|
11
14
|
Enforces a lower bound on the parameter.
|
12
15
|
"""
|
13
16
|
|
14
|
-
lb:
|
17
|
+
lb: Scalar
|
15
18
|
|
16
19
|
def __init__(self, lb: ScalarLike):
|
17
|
-
self.lb = lb
|
20
|
+
self.lb = jnp.array(lb)
|
18
21
|
|
19
|
-
|
22
|
+
@eqx.filter_jit
|
23
|
+
def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
|
20
24
|
"""
|
21
|
-
|
25
|
+
Enforces a lower bound on the parameter and adjusts the posterior density.
|
22
26
|
|
23
27
|
# Parameters
|
24
|
-
- `x`: The unconstrained
|
28
|
+
- `x`: The unconstrained `Parameter`.
|
25
29
|
|
26
30
|
# Parameters
|
27
31
|
A tuple containing:
|
28
|
-
-
|
29
|
-
- A scalar
|
32
|
+
- A modified `Parameter` with relevant leaves satisfying the constraint.
|
33
|
+
- A scalar Array representing the log-absolute-Jacobian of the transformation.
|
30
34
|
"""
|
31
|
-
#
|
32
|
-
|
35
|
+
# Extract relevant filter specification
|
36
|
+
filter_spec = x.filter_spec
|
37
|
+
|
38
|
+
# Extract relevant parameters(all Array)
|
39
|
+
dyn_params, static_params = eqx.partition(x, filter_spec)
|
40
|
+
|
41
|
+
# Compute density adjustment
|
42
|
+
laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
|
43
|
+
laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
|
33
44
|
|
34
45
|
# Compute transformation
|
35
|
-
|
46
|
+
dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
|
47
|
+
|
48
|
+
# Combine into full parameter object
|
49
|
+
x = eqx.combine(dyn_params, static_params)
|
36
50
|
|
37
51
|
return x, laj
|
bayinx/core/__init__.py
CHANGED
bayinx/core/constraint.py
CHANGED
@@ -2,7 +2,9 @@ from abc import abstractmethod
|
|
2
2
|
from typing import Tuple
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
|
-
from jaxtyping import
|
5
|
+
from jaxtyping import Scalar
|
6
|
+
|
7
|
+
from bayinx.core.parameter import Parameter
|
6
8
|
|
7
9
|
|
8
10
|
class Constraint(eqx.Module):
|
@@ -11,16 +13,16 @@ class Constraint(eqx.Module):
|
|
11
13
|
"""
|
12
14
|
|
13
15
|
@abstractmethod
|
14
|
-
def constrain(self, x:
|
16
|
+
def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
|
15
17
|
"""
|
16
|
-
Applies the constraining transformation to
|
18
|
+
Applies the constraining transformation to a parameter and computes the log-absolute-Jacobian of the transformation.
|
17
19
|
|
18
20
|
# Parameters
|
19
|
-
- `x`: The unconstrained
|
21
|
+
- `x`: The unconstrained `Parameter`.
|
20
22
|
|
21
23
|
# Returns
|
22
24
|
A tuple containing:
|
23
|
-
- The constrained
|
24
|
-
- A scalar
|
25
|
+
- The constrained `Parameter`.
|
26
|
+
- A scalar Array representing the log-absolute-Jacobian of the transformation.
|
25
27
|
"""
|
26
28
|
pass
|
bayinx/core/flow.py
CHANGED
@@ -31,11 +31,13 @@ class Flow(eqx.Module):
|
|
31
31
|
Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.
|
32
32
|
|
33
33
|
# Returns
|
34
|
-
|
34
|
+
A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
|
35
35
|
"""
|
36
36
|
pass
|
37
37
|
|
38
38
|
# Default filter specification
|
39
|
+
@property
|
40
|
+
@eqx.filter_jit
|
39
41
|
def filter_spec(self):
|
40
42
|
"""
|
41
43
|
Generates a filter specification to subset relevant parameters for the flow.
|
@@ -53,7 +55,7 @@ class Flow(eqx.Module):
|
|
53
55
|
return filter_spec
|
54
56
|
|
55
57
|
@eqx.filter_jit
|
56
|
-
def
|
58
|
+
def constrain_params(self: Self):
|
57
59
|
"""
|
58
60
|
Constrain `params` to the appropriate domain.
|
59
61
|
|
@@ -68,11 +70,11 @@ class Flow(eqx.Module):
|
|
68
70
|
return t_params
|
69
71
|
|
70
72
|
@eqx.filter_jit
|
71
|
-
def
|
73
|
+
def transform_params(self: Self) -> Dict[str, Array]:
|
72
74
|
"""
|
73
75
|
Apply a custom transformation to `params` if needed.
|
74
76
|
|
75
77
|
# Returns
|
76
78
|
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
77
79
|
"""
|
78
|
-
return self.
|
80
|
+
return self.constrain_params()
|
bayinx/core/model.py
CHANGED
@@ -7,6 +7,7 @@ import jax.tree as jt
|
|
7
7
|
from jaxtyping import Array, PyTree, Scalar
|
8
8
|
|
9
9
|
from bayinx.core.constraint import Constraint
|
10
|
+
from bayinx.core.parameter import Parameter
|
10
11
|
|
11
12
|
|
12
13
|
class Model(eqx.Module):
|
@@ -14,11 +15,11 @@ class Model(eqx.Module):
|
|
14
15
|
An abstract base class used to define probabilistic models.
|
15
16
|
|
16
17
|
# Attributes
|
17
|
-
- `params`: A dictionary of
|
18
|
+
- `params`: A dictionary of Arrays or PyTrees containing Arrays representing parameters of the model.
|
18
19
|
- `constraints`: A dictionary of constraints.
|
19
20
|
"""
|
20
21
|
|
21
|
-
params: Dict[str,
|
22
|
+
params: Dict[str, Parameter]
|
22
23
|
constraints: Dict[str, Constraint]
|
23
24
|
|
24
25
|
@abstractmethod
|
@@ -26,6 +27,8 @@ class Model(eqx.Module):
|
|
26
27
|
pass
|
27
28
|
|
28
29
|
# Default filter specification
|
30
|
+
@property
|
31
|
+
@eqx.filter_jit
|
29
32
|
def filter_spec(self):
|
30
33
|
"""
|
31
34
|
Generates a filter specification to subset relevant parameters for the model.
|
@@ -33,25 +36,25 @@ class Model(eqx.Module):
|
|
33
36
|
# Generate empty specification
|
34
37
|
filter_spec = jt.map(lambda _: False, self)
|
35
38
|
|
36
|
-
# Specify
|
39
|
+
# Specify relevant parameters
|
37
40
|
filter_spec = eqx.tree_at(
|
38
41
|
lambda model: model.params,
|
39
42
|
filter_spec,
|
40
|
-
replace=
|
43
|
+
replace={key: param.filter_spec for key, param in self.params.items()}
|
41
44
|
)
|
42
45
|
|
43
46
|
return filter_spec
|
44
47
|
|
45
48
|
# Add constrain method
|
46
49
|
@eqx.filter_jit
|
47
|
-
def
|
50
|
+
def constrain_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
|
48
51
|
"""
|
49
52
|
Constrain `params` to the appropriate domain.
|
50
53
|
|
51
54
|
# Returns
|
52
|
-
A dictionary of
|
55
|
+
A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
|
53
56
|
"""
|
54
|
-
t_params: Dict[str, Array] = self.params
|
57
|
+
t_params: Dict[str, Array | PyTree] = self.params
|
55
58
|
target: Scalar = jnp.array(0.0)
|
56
59
|
|
57
60
|
for par, map in self.constraints.items():
|
@@ -64,11 +67,11 @@ class Model(eqx.Module):
|
|
64
67
|
return t_params, target
|
65
68
|
|
66
69
|
# Add default transform method
|
67
|
-
def
|
70
|
+
def transform_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
|
68
71
|
"""
|
69
72
|
Apply a custom transformation to `params` if needed.
|
70
73
|
|
71
74
|
# Returns
|
72
75
|
A dictionary of transformed JAX Arrays representing the transformed parameters.
|
73
76
|
"""
|
74
|
-
return self.
|
77
|
+
return self.constrain_params()
|
bayinx/core/parameter.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Self
|
2
|
+
|
3
|
+
import equinox as eqx
|
4
|
+
import jax.tree as jt
|
5
|
+
from jaxtyping import Array, PyTree
|
6
|
+
|
7
|
+
|
8
|
+
class Parameter(eqx.Module):
|
9
|
+
"""
|
10
|
+
A container for a parameter of a `Model`.
|
11
|
+
|
12
|
+
Subclasses can be constructed for custom filter specifications(`filter_spec`).
|
13
|
+
|
14
|
+
# Attributes
|
15
|
+
- `vals`: The parameter's value(s).
|
16
|
+
"""
|
17
|
+
vals: Array | PyTree
|
18
|
+
|
19
|
+
|
20
|
+
def __init__(self, values: Array | PyTree):
|
21
|
+
# Insert parameter values
|
22
|
+
self.vals = values
|
23
|
+
|
24
|
+
# Default filter specification
|
25
|
+
@property
|
26
|
+
@eqx.filter_jit
|
27
|
+
def filter_spec(self) -> Self:
|
28
|
+
"""
|
29
|
+
Generates a filter specification to filter out static parameters.
|
30
|
+
"""
|
31
|
+
# Generate empty specification
|
32
|
+
filter_spec = jt.map(lambda _: False, self)
|
33
|
+
|
34
|
+
# Specify Array leaves
|
35
|
+
filter_spec = eqx.tree_at(
|
36
|
+
lambda params: params.vals,
|
37
|
+
filter_spec,
|
38
|
+
replace=jt.map(eqx.is_array_like, self.vals),
|
39
|
+
)
|
40
|
+
|
41
|
+
return filter_spec
|
bayinx/core/variational.py
CHANGED
@@ -103,7 +103,7 @@ class Variational(eqx.Module):
|
|
103
103
|
- `key`: A PRNG key.
|
104
104
|
"""
|
105
105
|
# Partition variational
|
106
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
106
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
107
107
|
|
108
108
|
# Construct scheduler
|
109
109
|
schedule: Schedule = opx.cosine_decay_schedule(
|
@@ -143,7 +143,7 @@ class Variational(eqx.Module):
|
|
143
143
|
|
144
144
|
# Compute updates
|
145
145
|
updates, opt_state = optim.update(
|
146
|
-
updates, opt_state, eqx.filter(dyn, dyn.filter_spec
|
146
|
+
updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
|
147
147
|
)
|
148
148
|
|
149
149
|
# Update variational distribution
|
@@ -46,7 +46,7 @@ class FullAffine(Flow):
|
|
46
46
|
|
47
47
|
@eqx.filter_jit
|
48
48
|
def forward(self, draws: Array) -> Array:
|
49
|
-
params = self.
|
49
|
+
params = self.transform_params()
|
50
50
|
|
51
51
|
# Extract parameters
|
52
52
|
shift: Array = params["shift"]
|
@@ -60,7 +60,7 @@ class FullAffine(Flow):
|
|
60
60
|
@eqx.filter_jit
|
61
61
|
@partial(jax.vmap, in_axes=(None, 0))
|
62
62
|
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
63
|
-
params = self.
|
63
|
+
params = self.transform_params()
|
64
64
|
|
65
65
|
# Extract parameters
|
66
66
|
shift: Array = params["shift"]
|
bayinx/mhx/vi/flows/planar.py
CHANGED
@@ -39,7 +39,7 @@ class Planar(Flow):
|
|
39
39
|
@eqx.filter_jit
|
40
40
|
@partial(jax.vmap, in_axes=(None, 0))
|
41
41
|
def forward(self, draws: Array) -> Array:
|
42
|
-
params = self.
|
42
|
+
params = self.transform_params()
|
43
43
|
|
44
44
|
# Extract parameters
|
45
45
|
w: Array = params["w"]
|
@@ -54,7 +54,7 @@ class Planar(Flow):
|
|
54
54
|
@eqx.filter_jit
|
55
55
|
@partial(jax.vmap, in_axes=(None, 0))
|
56
56
|
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
57
|
-
params = self.
|
57
|
+
params = self.transform_params()
|
58
58
|
|
59
59
|
# Extract parameters
|
60
60
|
w: Array = params["w"]
|
bayinx/mhx/vi/flows/radial.py
CHANGED
@@ -49,7 +49,7 @@ class Radial(Flow):
|
|
49
49
|
# Returns
|
50
50
|
The transformed samples.
|
51
51
|
"""
|
52
|
-
params = self.
|
52
|
+
params = self.transform_params()
|
53
53
|
|
54
54
|
# Extract parameters
|
55
55
|
alpha = params["alpha"]
|
@@ -67,7 +67,7 @@ class Radial(Flow):
|
|
67
67
|
@partial(jax.vmap, in_axes=(None, 0))
|
68
68
|
@eqx.filter_jit
|
69
69
|
def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
|
70
|
-
params = self.
|
70
|
+
params = self.transform_params()
|
71
71
|
|
72
72
|
# Extract parameters
|
73
73
|
alpha = params["alpha"]
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -29,7 +29,7 @@ class MeanField(Variational):
|
|
29
29
|
- `model`: A probabilistic `Model` object.
|
30
30
|
"""
|
31
31
|
# Partition model
|
32
|
-
params, self._constraints = eqx.partition(model, model.filter_spec
|
32
|
+
params, self._constraints = eqx.partition(model, model.filter_spec)
|
33
33
|
|
34
34
|
# Flatten params component
|
35
35
|
params, self._unflatten = ravel_pytree(params)
|
@@ -40,6 +40,22 @@ class MeanField(Variational):
|
|
40
40
|
"log_std": jnp.zeros(params.size, dtype=params.dtype),
|
41
41
|
}
|
42
42
|
|
43
|
+
@property
|
44
|
+
@eqx.filter_jit
|
45
|
+
def filter_spec(self):
|
46
|
+
# Generate empty specification
|
47
|
+
filter_spec = jtu.tree_map(lambda _: False, self)
|
48
|
+
|
49
|
+
# Specify variational parameters
|
50
|
+
filter_spec = eqx.tree_at(
|
51
|
+
lambda mf: mf.var_params,
|
52
|
+
filter_spec,
|
53
|
+
replace=True,
|
54
|
+
)
|
55
|
+
|
56
|
+
return filter_spec
|
57
|
+
|
58
|
+
|
43
59
|
@eqx.filter_jit
|
44
60
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
45
61
|
# Sample variational draws
|
@@ -59,23 +75,9 @@ class MeanField(Variational):
|
|
59
75
|
sigma=jnp.exp(self.var_params["log_std"]),
|
60
76
|
).sum(axis=1)
|
61
77
|
|
62
|
-
@eqx.filter_jit
|
63
|
-
def filter_spec(self):
|
64
|
-
# Generate empty specification
|
65
|
-
filter_spec = jtu.tree_map(lambda _: False, self)
|
66
|
-
|
67
|
-
# Specify variational parameters
|
68
|
-
filter_spec = eqx.tree_at(
|
69
|
-
lambda mf: mf.var_params,
|
70
|
-
filter_spec,
|
71
|
-
replace=True,
|
72
|
-
)
|
73
|
-
|
74
|
-
return filter_spec
|
75
|
-
|
76
78
|
@eqx.filter_jit
|
77
79
|
def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
|
78
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
80
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
79
81
|
|
80
82
|
@eqx.filter_jit
|
81
83
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
|
@@ -97,7 +99,7 @@ class MeanField(Variational):
|
|
97
99
|
|
98
100
|
@eqx.filter_jit
|
99
101
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
100
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
102
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
101
103
|
|
102
104
|
@eqx.filter_grad
|
103
105
|
@eqx.filter_jit
|
@@ -33,7 +33,7 @@ class NormalizingFlow(Variational):
|
|
33
33
|
- `model`: A probabilistic `Model` object.
|
34
34
|
"""
|
35
35
|
# Partition model
|
36
|
-
params, self._constraints = eqx.partition(model,
|
36
|
+
params, self._constraints = eqx.partition(model, model.filter_spec)
|
37
37
|
|
38
38
|
# Flatten params component
|
39
39
|
_, self._unflatten = jfu.ravel_pytree(params)
|
@@ -41,6 +41,8 @@ class NormalizingFlow(Variational):
|
|
41
41
|
self.base = base
|
42
42
|
self.flows = flows
|
43
43
|
|
44
|
+
@property
|
45
|
+
@eqx.filter_jit
|
44
46
|
def filter_spec(self):
|
45
47
|
# Generate empty specification
|
46
48
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
@@ -49,7 +51,7 @@ class NormalizingFlow(Variational):
|
|
49
51
|
filter_spec = eqx.tree_at(
|
50
52
|
lambda vari: vari.flows,
|
51
53
|
filter_spec,
|
52
|
-
replace=[flow.filter_spec
|
54
|
+
replace=[flow.filter_spec for flow in self.flows],
|
53
55
|
)
|
54
56
|
|
55
57
|
return filter_spec
|
@@ -112,7 +114,7 @@ class NormalizingFlow(Variational):
|
|
112
114
|
|
113
115
|
@eqx.filter_jit
|
114
116
|
def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
|
115
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
117
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
116
118
|
|
117
119
|
@eqx.filter_jit
|
118
120
|
def elbo(dyn: Self, n: int, key: Key, data: Any = None):
|
@@ -129,7 +131,7 @@ class NormalizingFlow(Variational):
|
|
129
131
|
|
130
132
|
@eqx.filter_jit
|
131
133
|
def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
|
132
|
-
dyn, static = eqx.partition(self, self.filter_spec
|
134
|
+
dyn, static = eqx.partition(self, self.filter_spec)
|
133
135
|
|
134
136
|
@eqx.filter_grad
|
135
137
|
@eqx.filter_jit
|
bayinx/mhx/vi/standard.py
CHANGED
@@ -19,7 +19,7 @@ class Standard(Variational):
|
|
19
19
|
- `dim`: Dimension of the parameter space.
|
20
20
|
"""
|
21
21
|
|
22
|
-
dim: int
|
22
|
+
dim: int
|
23
23
|
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
24
|
_constraints: Model
|
25
25
|
|
@@ -31,7 +31,7 @@ class Standard(Variational):
|
|
31
31
|
- `model`: A probabilistic `Model` object.
|
32
32
|
"""
|
33
33
|
# Partition model
|
34
|
-
params, self._constraints = eqx.partition(model, model.filter_spec
|
34
|
+
params, self._constraints = eqx.partition(model, model.filter_spec)
|
35
35
|
|
36
36
|
# Flatten params component
|
37
37
|
params, self._unflatten = ravel_pytree(params)
|
@@ -54,7 +54,7 @@ class Standard(Variational):
|
|
54
54
|
sigma=jnp.array(1.0),
|
55
55
|
).sum(axis=1, keepdims=True)
|
56
56
|
|
57
|
-
@
|
57
|
+
@property
|
58
58
|
def filter_spec(self):
|
59
59
|
filter_spec = jtu.tree_map(lambda _: False, self)
|
60
60
|
|
@@ -0,0 +1,30 @@
|
|
1
|
+
bayinx/__init__.py,sha256=htihTsJ54k-ljBLzt4ye8DR7ORwZhxv-bLMcEaDQeqY,86
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
|
5
|
+
bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
|
6
|
+
bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
|
7
|
+
bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
|
8
|
+
bayinx/core/model.py,sha256=QnJUKaR6d5RCe_WIxD2oJtI8NJyFKWUWyCRVwOm0j3s,2276
|
9
|
+
bayinx/core/parameter.py,sha256=fdyzun6TDnXxQT_KlarIJvWzn9n8bQgzfiVjWIIHk6k,998
|
10
|
+
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
|
+
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
|
+
bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
14
|
+
bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
|
15
|
+
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
16
|
+
bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
|
18
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
19
|
+
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
20
|
+
bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
|
21
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
|
22
|
+
bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
|
23
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
24
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
25
|
+
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
26
|
+
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
27
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
28
|
+
bayinx-0.3.0.dist-info/METADATA,sha256=RLbnLgyMmnEh2BJmqex3MMWFFS3HgSU9NEeQEvkyfC0,3057
|
29
|
+
bayinx-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
30
|
+
bayinx-0.3.0.dist-info/RECORD,,
|
bayinx-0.2.33.dist-info/RECORD
DELETED
@@ -1,29 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
bayinx/constraints/lower.py,sha256=MAAsWpZhqu1ySMskQ0fPVkCzW6CVDCU67q2bkCyzbLc,936
|
5
|
-
bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
|
6
|
-
bayinx/core/constraint.py,sha256=60KzDILVLQTCY3jt4YEnNKJ5gnG8IHTv_nNqrdwt_3c,751
|
7
|
-
bayinx/core/flow.py,sha256=A5Vw5t76LPasnMgghjw6ulBkIm5L2jBprusVt-tuwko,2296
|
8
|
-
bayinx/core/model.py,sha256=7Gt7HkFLzSUbRY9PxTDp6CrXzmld25NL9aQo34ePeno,2135
|
9
|
-
bayinx/core/variational.py,sha256=2stsYKZDri1rLP7mrz7X2GWehBXNESdlWtmF2N9CEas,4787
|
10
|
-
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
12
|
-
bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
13
|
-
bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
|
14
|
-
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
15
|
-
bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
|
-
bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
|
17
|
-
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
18
|
-
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
19
|
-
bayinx/mhx/vi/meanfield.py,sha256=BobfTagVGA5R-dclv-E0jSA80KZg1X6GGjiw7XR61vE,3643
|
20
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=DYhvTiu2Fr5x8KpWAMZVUaio7ctG2X2SMUO0l8zfZ5g,4622
|
21
|
-
bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
|
22
|
-
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
23
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=Z_G2Cg90Asgvqel8buMx5kEVsiIxDDh8rc-L_nP9OCY,1950
|
24
|
-
bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
|
25
|
-
bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
|
26
|
-
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
27
|
-
bayinx-0.2.33.dist-info/METADATA,sha256=8d-BDtz7NrXSs5kJd-9Yr5zHTzEPtQvhgZGD-3VX7FI,3058
|
28
|
-
bayinx-0.2.33.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
29
|
-
bayinx-0.2.33.dist-info/RECORD,,
|
File without changes
|