bayinx 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +3 -3
- bayinx/constraints/__init__.py +3 -1
- bayinx/constraints/lower.py +3 -4
- bayinx/core/__init__.py +7 -4
- bayinx/core/{constraint.py → _constraint.py} +1 -1
- bayinx/core/{model.py → _model.py} +7 -11
- bayinx/core/{parameter.py → _parameter.py} +2 -2
- bayinx/core/{variational.py → _variational.py} +28 -9
- bayinx/dists/__init__.py +3 -0
- bayinx/dists/censored/__init__.py +3 -0
- bayinx/dists/censored/gamma2/__init__.py +3 -1
- bayinx/dists/censored/gamma2/r.py +16 -13
- bayinx/dists/censored/posnormal/__init__.py +3 -0
- bayinx/dists/censored/posnormal/r.py +78 -0
- bayinx/dists/gamma2.py +12 -2
- bayinx/dists/normal.py +82 -24
- bayinx/dists/posnormal.py +260 -0
- bayinx/mhx/vi/__init__.py +5 -3
- bayinx/mhx/vi/meanfield.py +5 -6
- bayinx/mhx/vi/normalizing_flow.py +4 -4
- bayinx/mhx/vi/standard.py +4 -8
- {bayinx-0.3.4.dist-info → bayinx-0.3.6.dist-info}/METADATA +2 -1
- bayinx-0.3.6.dist-info/RECORD +35 -0
- bayinx-0.3.6.dist-info/licenses/LICENSE +21 -0
- bayinx-0.3.4.dist-info/RECORD +0 -31
- /bayinx/core/{flow.py → _flow.py} +0 -0
- {bayinx-0.3.4.dist-info → bayinx-0.3.6.dist-info}/WHEEL +0 -0
bayinx/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1
|
-
from bayinx.core import Model
|
2
|
-
|
3
|
-
|
1
|
+
from bayinx.core import Model, Parameter, constrain
|
2
|
+
|
3
|
+
__all__ = ["Model", "Parameter", "constrain"]
|
bayinx/constraints/__init__.py
CHANGED
bayinx/constraints/lower.py
CHANGED
@@ -5,8 +5,7 @@ import jax.numpy as jnp
|
|
5
5
|
import jax.tree as jt
|
6
6
|
from jaxtyping import PyTree, Scalar, ScalarLike
|
7
7
|
|
8
|
-
from bayinx.core
|
9
|
-
from bayinx.core.parameter import Parameter
|
8
|
+
from bayinx.core import Constraint, Parameter
|
10
9
|
|
11
10
|
|
12
11
|
class Lower(Constraint):
|
@@ -39,8 +38,8 @@ class Lower(Constraint):
|
|
39
38
|
dyn_params, static_params = eqx.partition(x, filter_spec)
|
40
39
|
|
41
40
|
# Compute density adjustment
|
42
|
-
laj: PyTree = jt.map(jnp.sum, dyn_params)
|
43
|
-
laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
|
41
|
+
laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
|
42
|
+
laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
|
44
43
|
|
45
44
|
# Compute transformation
|
46
45
|
dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
|
bayinx/core/__init__.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
-
from
|
2
|
-
from
|
3
|
-
from
|
4
|
-
from
|
1
|
+
from ._constraint import Constraint
|
2
|
+
from ._flow import Flow
|
3
|
+
from ._model import Model, constrain
|
4
|
+
from ._parameter import Parameter
|
5
|
+
from ._variational import Variational
|
6
|
+
|
7
|
+
__all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
|
@@ -7,13 +7,13 @@ import jax.numpy as jnp
|
|
7
7
|
import jax.tree as jt
|
8
8
|
from jaxtyping import Scalar
|
9
9
|
|
10
|
-
from
|
11
|
-
from
|
10
|
+
from ._constraint import Constraint
|
11
|
+
from ._parameter import Parameter
|
12
12
|
|
13
13
|
|
14
14
|
def constrain(constraint: Constraint):
|
15
15
|
"""Defines constraint metadata."""
|
16
|
-
return field(metadata={
|
16
|
+
return field(metadata={"constraint": constraint})
|
17
17
|
|
18
18
|
|
19
19
|
class Model(eqx.Module):
|
@@ -49,12 +49,11 @@ class Model(eqx.Module):
|
|
49
49
|
filter_spec = eqx.tree_at(
|
50
50
|
lambda model: getattr(model, f.name),
|
51
51
|
filter_spec,
|
52
|
-
replace=attr.filter_spec
|
52
|
+
replace=attr.filter_spec,
|
53
53
|
)
|
54
54
|
|
55
55
|
return filter_spec
|
56
56
|
|
57
|
-
|
58
57
|
@eqx.filter_jit
|
59
58
|
def constrain_params(self) -> Tuple[Self, Scalar]:
|
60
59
|
"""
|
@@ -71,18 +70,16 @@ class Model(eqx.Module):
|
|
71
70
|
attr = getattr(self, f.name)
|
72
71
|
|
73
72
|
# Check if constrained parameter
|
74
|
-
if isinstance(attr, Parameter) and
|
73
|
+
if isinstance(attr, Parameter) and "constraint" in f.metadata:
|
75
74
|
param = attr
|
76
|
-
constraint = f.metadata[
|
75
|
+
constraint = f.metadata["constraint"]
|
77
76
|
|
78
77
|
# Apply constraint
|
79
78
|
param, laj = constraint.constrain(param)
|
80
79
|
|
81
80
|
# Update parameters for constrained model
|
82
81
|
constrained = eqx.tree_at(
|
83
|
-
lambda model: getattr(model, f.name),
|
84
|
-
constrained,
|
85
|
-
replace=param
|
82
|
+
lambda model: getattr(model, f.name), constrained, replace=param
|
86
83
|
)
|
87
84
|
|
88
85
|
# Adjust posterior density
|
@@ -90,7 +87,6 @@ class Model(eqx.Module):
|
|
90
87
|
|
91
88
|
return constrained, target
|
92
89
|
|
93
|
-
|
94
90
|
@eqx.filter_jit
|
95
91
|
def transform_params(self) -> Tuple[Self, Scalar]:
|
96
92
|
"""
|
@@ -4,7 +4,7 @@ import equinox as eqx
|
|
4
4
|
import jax.tree as jt
|
5
5
|
from jaxtyping import PyTree
|
6
6
|
|
7
|
-
T = TypeVar(
|
7
|
+
T = TypeVar("T", bound=PyTree)
|
8
8
|
class Parameter(eqx.Module, Generic[T]):
|
9
9
|
"""
|
10
10
|
A container for a parameter of a `Model`.
|
@@ -14,8 +14,8 @@ class Parameter(eqx.Module, Generic[T]):
|
|
14
14
|
# Attributes
|
15
15
|
- `vals`: The parameter's value(s).
|
16
16
|
"""
|
17
|
-
vals: T
|
18
17
|
|
18
|
+
vals: T
|
19
19
|
|
20
20
|
def __init__(self, values: T):
|
21
21
|
# Insert parameter values
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from abc import abstractmethod
|
2
2
|
from functools import partial
|
3
|
-
from typing import Any, Callable, Self, Tuple
|
3
|
+
from typing import Any, Callable, Generic, Self, Tuple, TypeVar
|
4
4
|
|
5
5
|
import equinox as eqx
|
6
6
|
import jax
|
@@ -11,10 +11,10 @@ import optax as opx
|
|
11
11
|
from jaxtyping import Array, Key, PyTree, Scalar
|
12
12
|
from optax import GradientTransformation, OptState, Schedule
|
13
13
|
|
14
|
-
from
|
14
|
+
from ._model import Model
|
15
15
|
|
16
|
-
|
17
|
-
class Variational(eqx.Module):
|
16
|
+
M = TypeVar('M', bound=Model)
|
17
|
+
class Variational(eqx.Module, Generic[M]):
|
18
18
|
"""
|
19
19
|
An abstract base class used to define variational methods.
|
20
20
|
|
@@ -23,8 +23,8 @@ class Variational(eqx.Module):
|
|
23
23
|
- `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
|
24
24
|
"""
|
25
25
|
|
26
|
-
_unflatten: Callable[[Array],
|
27
|
-
_constraints:
|
26
|
+
_unflatten: Callable[[Array], M]
|
27
|
+
_constraints: M
|
28
28
|
|
29
29
|
@abstractmethod
|
30
30
|
def filter_spec(self):
|
@@ -34,7 +34,7 @@ class Variational(eqx.Module):
|
|
34
34
|
pass
|
35
35
|
|
36
36
|
@abstractmethod
|
37
|
-
def sample(self, n: int, key: Key) -> Array:
|
37
|
+
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
38
38
|
"""
|
39
39
|
Sample from the variational distribution.
|
40
40
|
"""
|
@@ -72,10 +72,10 @@ class Variational(eqx.Module):
|
|
72
72
|
- `data`: Data used to evaluate the posterior(if needed).
|
73
73
|
"""
|
74
74
|
# Unflatten variational draw
|
75
|
-
model:
|
75
|
+
model: M = self._unflatten(draws)
|
76
76
|
|
77
77
|
# Combine with constraints
|
78
|
-
model:
|
78
|
+
model: M = eqx.combine(model, self._constraints)
|
79
79
|
|
80
80
|
# Evaluate posterior density
|
81
81
|
return model.eval(data)
|
@@ -160,3 +160,22 @@ class Variational(eqx.Module):
|
|
160
160
|
|
161
161
|
# Return optimized variational
|
162
162
|
return eqx.combine(dyn, static)
|
163
|
+
|
164
|
+
@eqx.filter_jit
|
165
|
+
def posterior_predictive(
|
166
|
+
self, func: Callable[[M], Array], n: int, key: Key = jr.PRNGKey(0)
|
167
|
+
) -> Array:
|
168
|
+
# Sample draws from the variational approximation
|
169
|
+
draws: Array = self.sample(n, key)
|
170
|
+
|
171
|
+
# Evaluate posterior predictive
|
172
|
+
@jax.jit
|
173
|
+
@jax.vmap
|
174
|
+
def evaluate(draw: Array):
|
175
|
+
# Reconstruct model
|
176
|
+
model: M = self._unflatten(draw)
|
177
|
+
|
178
|
+
# Evaluate
|
179
|
+
return func(model)
|
180
|
+
|
181
|
+
return evaluate(draws)
|
bayinx/dists/__init__.py
CHANGED
@@ -10,7 +10,7 @@ def prob(
|
|
10
10
|
x: Float[ArrayLike, "..."],
|
11
11
|
mu: Float[ArrayLike, "..."],
|
12
12
|
nu: Float[ArrayLike, "..."],
|
13
|
-
censor: Float[ArrayLike, "..."]
|
13
|
+
censor: Float[ArrayLike, "..."],
|
14
14
|
) -> Float[Array, "..."]:
|
15
15
|
"""
|
16
16
|
The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
|
@@ -19,19 +19,20 @@ def prob(
|
|
19
19
|
- `x`: Value(s) at which to evaluate the PMF/PDF.
|
20
20
|
- `mu`: The positive mean.
|
21
21
|
- `nu`: The positive inverse dispersion.
|
22
|
+
- `censor`: The positive censor value.
|
22
23
|
|
23
24
|
# Returns
|
24
|
-
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `
|
25
|
+
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
|
25
26
|
"""
|
26
|
-
evals: Array = jnp.zeros_like(x * 1.0)
|
27
|
+
evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
|
27
28
|
|
28
29
|
# Construct boolean masks
|
29
|
-
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor))
|
30
|
-
censored: Array = jnp.array(x == censor)
|
30
|
+
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
|
31
|
+
censored: Array = jnp.array(x == censor) # pyright: ignore
|
31
32
|
|
32
|
-
# Evaluate
|
33
|
+
# Evaluate probability mass/density function
|
33
34
|
evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
|
34
|
-
evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals)
|
35
|
+
evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
|
35
36
|
|
36
37
|
return evals
|
37
38
|
|
@@ -40,7 +41,7 @@ def logprob(
|
|
40
41
|
x: Float[ArrayLike, "..."],
|
41
42
|
mu: Float[ArrayLike, "..."],
|
42
43
|
nu: Float[ArrayLike, "..."],
|
43
|
-
censor: Float[ArrayLike, "..."]
|
44
|
+
censor: Float[ArrayLike, "..."],
|
44
45
|
) -> Float[Array, "..."]:
|
45
46
|
"""
|
46
47
|
The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
|
@@ -49,17 +50,19 @@ def logprob(
|
|
49
50
|
- `x`: Value(s) at which to evaluate the log PMF/PDF.
|
50
51
|
- `mu`: The positive mean/location.
|
51
52
|
- `nu`: The positive inverse dispersion.
|
53
|
+
- `censor`: The positive censor value.
|
52
54
|
|
53
55
|
# Returns
|
54
|
-
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `
|
56
|
+
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
|
55
57
|
"""
|
56
|
-
evals: Array = jnp.full_like(x * 1.0, -jnp.inf)
|
58
|
+
evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
|
57
59
|
|
58
60
|
# Construct boolean masks
|
59
|
-
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor))
|
60
|
-
censored: Array = jnp.array(x == censor)
|
61
|
+
uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
|
62
|
+
censored: Array = jnp.array(x == censor) # pyright: ignore
|
61
63
|
|
64
|
+
# Evaluate log probability mass/density function
|
62
65
|
evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
|
63
|
-
evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals)
|
66
|
+
evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
|
64
67
|
|
65
68
|
return evals
|
@@ -0,0 +1,78 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
from jaxtyping import Array, ArrayLike, Float
|
3
|
+
|
4
|
+
from bayinx.dists import posnormal
|
5
|
+
|
6
|
+
|
7
|
+
def prob(
|
8
|
+
x: Float[ArrayLike, "..."],
|
9
|
+
mu: Float[ArrayLike, "..."],
|
10
|
+
sigma: Float[ArrayLike, "..."],
|
11
|
+
censor: Float[ArrayLike, "..."],
|
12
|
+
) -> Float[Array, "..."]:
|
13
|
+
"""
|
14
|
+
The mixed probability mass/density function (PMF/PDF) for a right-censored positive Normal distribution.
|
15
|
+
|
16
|
+
# Parameters
|
17
|
+
- `x`: Value(s) at which to evaluate the PMF/PDF.
|
18
|
+
- `mu`: The mean.
|
19
|
+
- `sigma`: The positive standard deviation.
|
20
|
+
- `censor`: The positive censor value.
|
21
|
+
|
22
|
+
# Returns
|
23
|
+
The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
|
24
|
+
"""
|
25
|
+
# Cast to Array
|
26
|
+
x, mu, sigma, censor = (
|
27
|
+
jnp.asarray(x),
|
28
|
+
jnp.asarray(mu),
|
29
|
+
jnp.asarray(sigma),
|
30
|
+
jnp.asarray(censor),
|
31
|
+
)
|
32
|
+
|
33
|
+
# Construct boolean masks
|
34
|
+
uncensored: Array = jnp.logical_and(0.0 < x, x < censor)
|
35
|
+
censored: Array = x == censor
|
36
|
+
|
37
|
+
# Evaluate probability mass/density function
|
38
|
+
evals = jnp.where(uncensored, posnormal.prob(x, mu, sigma), 0.0)
|
39
|
+
evals = jnp.where(censored, posnormal.ccdf(x, mu, sigma), evals)
|
40
|
+
|
41
|
+
return evals
|
42
|
+
|
43
|
+
|
44
|
+
def logprob(
|
45
|
+
x: Float[ArrayLike, "..."],
|
46
|
+
mu: Float[ArrayLike, "..."],
|
47
|
+
sigma: Float[ArrayLike, "..."],
|
48
|
+
censor: Float[ArrayLike, "..."],
|
49
|
+
) -> Float[Array, "..."]:
|
50
|
+
"""
|
51
|
+
The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored positive Normal distribution.
|
52
|
+
|
53
|
+
# Parameters
|
54
|
+
- `x`: Where to evaluate the log PMF/PDF.
|
55
|
+
- `mu`: The mean.
|
56
|
+
- `sigma`: The positive standard deviation.
|
57
|
+
- `censor`: The positive censor value.
|
58
|
+
|
59
|
+
# Returns
|
60
|
+
The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
|
61
|
+
"""
|
62
|
+
# Cast to Array
|
63
|
+
x, mu, sigma, censor = (
|
64
|
+
jnp.asarray(x),
|
65
|
+
jnp.asarray(mu),
|
66
|
+
jnp.asarray(sigma),
|
67
|
+
jnp.asarray(censor),
|
68
|
+
)
|
69
|
+
|
70
|
+
# Construct boolean masks for censoring
|
71
|
+
uncensored: Array = jnp.logical_and(jnp.asarray(0.0) < x, x < censor)
|
72
|
+
censored: Array = x == censor
|
73
|
+
|
74
|
+
# Evaluate log probability mass/density function
|
75
|
+
evals = jnp.where(uncensored, posnormal.logprob(x, mu, sigma), -jnp.inf)
|
76
|
+
evals = jnp.where(censored, posnormal.logccdf(x, mu, sigma), evals)
|
77
|
+
|
78
|
+
return evals
|
bayinx/dists/gamma2.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import jax.lax as lax
|
2
|
+
import jax.numpy as jnp
|
2
3
|
from jax.scipy.special import gammaln
|
3
4
|
from jaxtyping import Array, ArrayLike, Float
|
4
5
|
|
@@ -17,6 +18,8 @@ def prob(
|
|
17
18
|
# Returns
|
18
19
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
|
19
20
|
"""
|
21
|
+
# Cast to Array
|
22
|
+
x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
|
20
23
|
|
21
24
|
return lax.exp(logprob(x, mu, nu))
|
22
25
|
|
@@ -35,5 +38,12 @@ def logprob(
|
|
35
38
|
# Returns
|
36
39
|
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
|
37
40
|
"""
|
38
|
-
|
39
|
-
|
41
|
+
# Cast to Array
|
42
|
+
x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
|
43
|
+
|
44
|
+
return (
|
45
|
+
-gammaln(nu)
|
46
|
+
+ nu * (lax.log(nu) - lax.log(mu))
|
47
|
+
+ (nu - 1.0) * lax.log(x)
|
48
|
+
- (x * nu / mu)
|
49
|
+
) # pyright: ignore
|
bayinx/dists/normal.py
CHANGED
@@ -1,80 +1,138 @@
|
|
1
1
|
import jax.lax as lax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jax.scipy.special as jss
|
2
4
|
from jaxtyping import Array, ArrayLike, Float
|
3
5
|
|
4
6
|
__PI = 3.141592653589793
|
5
7
|
|
6
8
|
|
7
9
|
def prob(
|
8
|
-
x: Float[ArrayLike, "..."],
|
10
|
+
x: Float[ArrayLike, "..."],
|
11
|
+
mu: Float[ArrayLike, "..."],
|
12
|
+
sigma: Float[ArrayLike, "..."],
|
9
13
|
) -> Float[Array, "..."]:
|
10
14
|
"""
|
11
15
|
The probability density function (PDF) for a Normal distribution.
|
12
16
|
|
13
17
|
# Parameters
|
14
|
-
- `x`:
|
15
|
-
- `mu`:
|
16
|
-
- `sigma`:
|
18
|
+
- `x`: Where to evaluate the PDF.
|
19
|
+
- `mu`: The mean.
|
20
|
+
- `sigma`: The standard deviation.
|
17
21
|
|
18
22
|
# Returns
|
19
23
|
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
20
24
|
"""
|
25
|
+
# Cast to Array
|
26
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
21
27
|
|
22
|
-
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (
|
23
|
-
sigma * lax.sqrt(2.0 * __PI)
|
24
|
-
)
|
28
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (sigma * lax.sqrt(2.0 * __PI))
|
25
29
|
|
26
30
|
|
27
31
|
def logprob(
|
28
|
-
x: Float[ArrayLike, "..."],
|
32
|
+
x: Float[ArrayLike, "..."],
|
33
|
+
mu: Float[ArrayLike, "..."],
|
34
|
+
sigma: Float[ArrayLike, "..."],
|
29
35
|
) -> Float[Array, "..."]:
|
30
36
|
"""
|
31
37
|
The log of the probability density function (log PDF) for a Normal distribution.
|
32
38
|
|
33
39
|
# Parameters
|
34
|
-
- `x`:
|
35
|
-
- `mu`:
|
36
|
-
- `sigma`:
|
40
|
+
- `x`: Where to evaluate the log PDF.
|
41
|
+
- `mu`: The mean.
|
42
|
+
- `sigma`: The standard deviation.
|
37
43
|
|
38
44
|
# Returns
|
39
45
|
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
40
46
|
"""
|
47
|
+
# Cast to Array
|
48
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
41
49
|
|
42
|
-
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
|
43
|
-
(x - mu) / sigma # pyright: ignore
|
44
|
-
)
|
50
|
+
return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square((x - mu) / sigma)
|
45
51
|
|
46
52
|
|
47
53
|
def uprob(
|
48
|
-
x: Float[ArrayLike, "..."],
|
54
|
+
x: Float[ArrayLike, "..."],
|
55
|
+
mu: Float[ArrayLike, "..."],
|
56
|
+
sigma: Float[ArrayLike, "..."],
|
49
57
|
) -> Float[Array, "..."]:
|
50
58
|
"""
|
51
59
|
The unnormalized probability density function (uPDF) for a Normal distribution.
|
52
60
|
|
53
61
|
# Parameters
|
54
|
-
- `x`:
|
55
|
-
- `mu`:
|
56
|
-
- `sigma`:
|
62
|
+
- `x`: Where to evaluate the uPDF.
|
63
|
+
- `mu`: The mean.
|
64
|
+
- `sigma`: The standard deviation.
|
57
65
|
|
58
66
|
# Returns
|
59
67
|
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
60
68
|
"""
|
69
|
+
# Cast to Array
|
70
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
61
71
|
|
62
|
-
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
|
72
|
+
return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
|
63
73
|
|
64
74
|
|
65
75
|
def ulogprob(
|
66
|
-
x: Float[ArrayLike, "..."],
|
76
|
+
x: Float[ArrayLike, "..."],
|
77
|
+
mu: Float[ArrayLike, "..."],
|
78
|
+
sigma: Float[ArrayLike, "..."],
|
67
79
|
) -> Float[Array, "..."]:
|
68
80
|
"""
|
69
81
|
The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
|
70
82
|
|
71
83
|
# Parameters
|
72
|
-
- `x`:
|
73
|
-
- `mu`:
|
74
|
-
- `sigma`:
|
84
|
+
- `x`: Where to evaluate the log uPDF.
|
85
|
+
- `mu`: The mean.
|
86
|
+
- `sigma`: The standard deviation.
|
75
87
|
|
76
88
|
# Returns
|
77
89
|
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
78
90
|
"""
|
91
|
+
# Cast to Array
|
92
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
79
93
|
|
80
|
-
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
|
94
|
+
return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
|
95
|
+
|
96
|
+
|
97
|
+
def cdf(
|
98
|
+
x: Float[ArrayLike, "..."],
|
99
|
+
mu: Float[ArrayLike, "..."],
|
100
|
+
sigma: Float[ArrayLike, "..."],
|
101
|
+
) -> Float[Array, "..."]:
|
102
|
+
# Cast to Array
|
103
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
104
|
+
|
105
|
+
return jss.ndtr((x - mu) / sigma)
|
106
|
+
|
107
|
+
|
108
|
+
def logcdf(
|
109
|
+
x: Float[ArrayLike, "..."],
|
110
|
+
mu: Float[ArrayLike, "..."],
|
111
|
+
sigma: Float[ArrayLike, "..."],
|
112
|
+
) -> Float[Array, "..."]:
|
113
|
+
# Cast to Array
|
114
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
115
|
+
|
116
|
+
return jss.log_ndtr((x - mu) / sigma)
|
117
|
+
|
118
|
+
|
119
|
+
def ccdf(
|
120
|
+
x: Float[ArrayLike, "..."],
|
121
|
+
mu: Float[ArrayLike, "..."],
|
122
|
+
sigma: Float[ArrayLike, "..."],
|
123
|
+
) -> Float[Array, "..."]:
|
124
|
+
# Cast to Array
|
125
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
126
|
+
|
127
|
+
return jss.ndtr((mu - x) / sigma)
|
128
|
+
|
129
|
+
|
130
|
+
def logccdf(
|
131
|
+
x: Float[ArrayLike, "..."],
|
132
|
+
mu: Float[ArrayLike, "..."],
|
133
|
+
sigma: Float[ArrayLike, "..."],
|
134
|
+
) -> Float[Array, "..."]:
|
135
|
+
# Cast to Array
|
136
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
137
|
+
|
138
|
+
return jss.log_ndtr((mu - x) / sigma)
|
@@ -0,0 +1,260 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
from jaxtyping import Array, ArrayLike, Float
|
3
|
+
|
4
|
+
from bayinx.dists import normal
|
5
|
+
|
6
|
+
|
7
|
+
def prob(
|
8
|
+
x: Float[ArrayLike, "..."],
|
9
|
+
mu: Float[ArrayLike, "..."],
|
10
|
+
sigma: Float[ArrayLike, "..."],
|
11
|
+
) -> Float[Array, "..."]:
|
12
|
+
"""
|
13
|
+
The probability density function (PDF) for a positive Normal distribution.
|
14
|
+
|
15
|
+
# Parameters
|
16
|
+
- `x`: Where to evaluate the PDF.
|
17
|
+
- `mu`: The mean.
|
18
|
+
- `sigma`: The standard deviation.
|
19
|
+
|
20
|
+
# Returns
|
21
|
+
The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
22
|
+
"""
|
23
|
+
# Cast to Array
|
24
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
25
|
+
|
26
|
+
# Construct boolean mask for non-negative elements
|
27
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
28
|
+
|
29
|
+
# Evaluate PDF
|
30
|
+
evals = jnp.where(
|
31
|
+
non_negative,
|
32
|
+
normal.prob(x, mu, sigma) / normal.cdf(mu / sigma, 0.0, 1.0),
|
33
|
+
jnp.asarray(0.0),
|
34
|
+
)
|
35
|
+
|
36
|
+
return evals
|
37
|
+
|
38
|
+
|
39
|
+
def logprob(
|
40
|
+
x: Float[ArrayLike, "..."],
|
41
|
+
mu: Float[ArrayLike, "..."],
|
42
|
+
sigma: Float[ArrayLike, "..."],
|
43
|
+
) -> Float[Array, "..."]:
|
44
|
+
"""
|
45
|
+
The log of the probability density function (log PDF) for a positive Normal distribution.
|
46
|
+
|
47
|
+
# Parameters
|
48
|
+
- `x`: Where to evaluate the log PDF.
|
49
|
+
- `mu`: The mean.
|
50
|
+
- `sigma`: The standard deviation.
|
51
|
+
|
52
|
+
# Returns
|
53
|
+
The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
54
|
+
"""
|
55
|
+
# Cast to Array
|
56
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
57
|
+
|
58
|
+
# Construct boolean mask for non-negative elements
|
59
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
60
|
+
|
61
|
+
# Evaluate log PDF
|
62
|
+
evals = jnp.where(
|
63
|
+
non_negative,
|
64
|
+
normal.logprob(x, mu, sigma) - normal.logcdf(mu / sigma, 0.0, 1.0),
|
65
|
+
-jnp.inf,
|
66
|
+
)
|
67
|
+
|
68
|
+
return evals
|
69
|
+
|
70
|
+
|
71
|
+
def uprob(
|
72
|
+
x: Float[ArrayLike, "..."],
|
73
|
+
mu: Float[ArrayLike, "..."],
|
74
|
+
sigma: Float[ArrayLike, "..."],
|
75
|
+
) -> Float[Array, "..."]:
|
76
|
+
"""
|
77
|
+
The unnormalized probability density function (uPDF) for a positive Normal distribution.
|
78
|
+
|
79
|
+
# Parameters
|
80
|
+
- `x`: Where to evaluate the uPDF.
|
81
|
+
- `mu`: The mean.
|
82
|
+
- `sigma`: The standard deviation.
|
83
|
+
|
84
|
+
# Returns
|
85
|
+
The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
86
|
+
"""
|
87
|
+
# Cast to Array
|
88
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
89
|
+
|
90
|
+
# Construct boolean mask for non-negative elements
|
91
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
92
|
+
|
93
|
+
# Evaluate uPDF
|
94
|
+
evals = jnp.where(non_negative, normal.prob(x, mu, sigma), jnp.asarray(0.0))
|
95
|
+
|
96
|
+
return evals
|
97
|
+
|
98
|
+
|
99
|
+
def ulogprob(
|
100
|
+
x: Float[ArrayLike, "..."],
|
101
|
+
mu: Float[ArrayLike, "..."],
|
102
|
+
sigma: Float[ArrayLike, "..."],
|
103
|
+
) -> Float[Array, "..."]:
|
104
|
+
"""
|
105
|
+
The log of the unnormalized probability density function (log uPDF) for a positive Normal distribution.
|
106
|
+
|
107
|
+
# Parameters
|
108
|
+
- `x`: Where to evaluate the log uPDF.
|
109
|
+
- `mu`: The mean.
|
110
|
+
- `sigma`: The standard deviation.
|
111
|
+
|
112
|
+
# Returns
|
113
|
+
The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
114
|
+
"""
|
115
|
+
# Cast to Array
|
116
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
117
|
+
|
118
|
+
# Construct boolean mask for non-negative elements
|
119
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
120
|
+
|
121
|
+
# Evaluate log uPDF
|
122
|
+
evals = jnp.where(non_negative, normal.logprob(x, mu, sigma), -jnp.inf)
|
123
|
+
|
124
|
+
return evals
|
125
|
+
|
126
|
+
|
127
|
+
def cdf(
|
128
|
+
x: Float[ArrayLike, "..."],
|
129
|
+
mu: Float[ArrayLike, "..."],
|
130
|
+
sigma: Float[ArrayLike, "..."],
|
131
|
+
) -> Float[Array, "..."]:
|
132
|
+
"""
|
133
|
+
The cumulative distribution function (CDF) for a positive Normal distribution.
|
134
|
+
|
135
|
+
# Parameters
|
136
|
+
- `x`: Where to evaluate the CDF.
|
137
|
+
- `mu`: The mean.
|
138
|
+
- `sigma`: The standard deviation.
|
139
|
+
|
140
|
+
# Returns
|
141
|
+
The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
142
|
+
|
143
|
+
# Notes
|
144
|
+
Not numerically stable for small `x`.
|
145
|
+
"""
|
146
|
+
# Cast to Array
|
147
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
148
|
+
|
149
|
+
# Construct boolean mask for non-negative elements
|
150
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
151
|
+
|
152
|
+
# Compute intermediates
|
153
|
+
A: Array = normal.cdf(x, mu, sigma)
|
154
|
+
B: Array = normal.cdf(-mu / sigma, 0.0, 1.0)
|
155
|
+
C: Array = normal.cdf(mu / sigma, 0.0, 1.0)
|
156
|
+
|
157
|
+
# Evaluate CDF
|
158
|
+
evals = jnp.where(non_negative, (A - B) / C, jnp.asarray(0.0))
|
159
|
+
|
160
|
+
return evals
|
161
|
+
|
162
|
+
|
163
|
+
# TODO: make numerically stable
|
164
|
+
def logcdf(
|
165
|
+
x: Float[ArrayLike, "..."],
|
166
|
+
mu: Float[ArrayLike, "..."],
|
167
|
+
sigma: Float[ArrayLike, "..."],
|
168
|
+
) -> Float[Array, "..."]:
|
169
|
+
"""
|
170
|
+
The log-transformed cumulative distribution function (log CDF) for a positive Normal distribution.
|
171
|
+
|
172
|
+
# Parameters
|
173
|
+
- `x`: Where to evaluate the log CDF.
|
174
|
+
- `mu`: The mean.
|
175
|
+
- `sigma`: The standard deviation.
|
176
|
+
|
177
|
+
# Returns
|
178
|
+
The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
179
|
+
|
180
|
+
# Notes
|
181
|
+
Not numerically stable for small `x`.
|
182
|
+
"""
|
183
|
+
# Cast to Array
|
184
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
185
|
+
|
186
|
+
# Construct boolean mask for non-negative elements
|
187
|
+
non_negative: Array = jnp.asarray(0.0) <= x
|
188
|
+
|
189
|
+
A: Array = normal.logcdf(x, mu, sigma)
|
190
|
+
B: Array = normal.logcdf(-mu / sigma, 0.0, 1.0)
|
191
|
+
C: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
|
192
|
+
|
193
|
+
# Evaluate log CDF
|
194
|
+
evals = jnp.where(non_negative, A + jnp.log1p(-jnp.exp(B - A)) - C, -jnp.inf)
|
195
|
+
|
196
|
+
return evals
|
197
|
+
|
198
|
+
|
199
|
+
def ccdf(
|
200
|
+
x: Float[ArrayLike, "..."],
|
201
|
+
mu: Float[ArrayLike, "..."],
|
202
|
+
sigma: Float[ArrayLike, "..."],
|
203
|
+
) -> Float[Array, "..."]:
|
204
|
+
"""
|
205
|
+
The complementary cumulative distribution function (cCDF) for a positive Normal distribution.
|
206
|
+
|
207
|
+
# Parameters
|
208
|
+
- `x`: Where to evaluate the cCDF.
|
209
|
+
- `mu`: The mean.
|
210
|
+
- `sigma`: The standard deviation.
|
211
|
+
|
212
|
+
# Returns
|
213
|
+
The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
214
|
+
"""
|
215
|
+
# Cast to arrays
|
216
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
217
|
+
|
218
|
+
# Construct boolean mask for non-negative elements
|
219
|
+
non_negative: Array = 0.0 <= x
|
220
|
+
|
221
|
+
# Compute intermediates
|
222
|
+
A: Array = normal.cdf(-x, -mu, sigma)
|
223
|
+
B: Array = normal.cdf(mu / sigma, 0.0, 1.0)
|
224
|
+
|
225
|
+
# Evaluate cCDF
|
226
|
+
evals = jnp.where(non_negative, A / B, jnp.asarray(1.0))
|
227
|
+
|
228
|
+
return evals
|
229
|
+
|
230
|
+
|
231
|
+
def logccdf(
|
232
|
+
x: Float[ArrayLike, "..."],
|
233
|
+
mu: Float[ArrayLike, "..."],
|
234
|
+
sigma: Float[ArrayLike, "..."],
|
235
|
+
) -> Float[Array, "..."]:
|
236
|
+
"""
|
237
|
+
The log-transformed complementary cumulative distribution function (log cCDF) for a positive Normal distribution.
|
238
|
+
|
239
|
+
# Parameters
|
240
|
+
- `x`: Where to evaluate the log cCDF.
|
241
|
+
- `mu`: The mean.
|
242
|
+
- `sigma`: The standard deviation.
|
243
|
+
|
244
|
+
# Returns
|
245
|
+
The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
|
246
|
+
"""
|
247
|
+
# Cast to arrays
|
248
|
+
x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
|
249
|
+
|
250
|
+
# Construct boolean mask for non-negative elements
|
251
|
+
non_negative: Array = 0.0 <= x
|
252
|
+
|
253
|
+
# Compute intermediates
|
254
|
+
A: Array = normal.logcdf(-x, -mu, sigma)
|
255
|
+
B: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
|
256
|
+
|
257
|
+
# Evaluate log cCDF
|
258
|
+
evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
|
259
|
+
|
260
|
+
return evals
|
bayinx/mhx/vi/__init__.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
-
from bayinx.mhx.vi.meanfield import MeanField
|
2
|
-
from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
|
3
|
-
from bayinx.mhx.vi.standard import Standard
|
1
|
+
from bayinx.mhx.vi.meanfield import MeanField
|
2
|
+
from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
|
3
|
+
from bayinx.mhx.vi.standard import Standard
|
4
|
+
|
5
|
+
__all__ = ['MeanField', 'NormalizingFlow', 'Standard']
|
bayinx/mhx/vi/meanfield.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Dict, Self
|
1
|
+
from typing import Any, Dict, Generic, Self, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.numpy as jnp
|
@@ -10,8 +10,8 @@ from jaxtyping import Array, Float, Key, Scalar
|
|
10
10
|
from bayinx.core import Model, Variational
|
11
11
|
from bayinx.dists import normal
|
12
12
|
|
13
|
-
|
14
|
-
class MeanField(Variational):
|
13
|
+
M = TypeVar('M', bound=Model)
|
14
|
+
class MeanField(Variational, Generic[M]):
|
15
15
|
"""
|
16
16
|
A fully factorized Gaussian approximation to a posterior distribution.
|
17
17
|
|
@@ -19,9 +19,9 @@ class MeanField(Variational):
|
|
19
19
|
- `var_params`: The variational parameters for the approximation.
|
20
20
|
"""
|
21
21
|
|
22
|
-
var_params: Dict[str, Float[Array, "..."]]
|
22
|
+
var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
|
23
23
|
|
24
|
-
def __init__(self, model:
|
24
|
+
def __init__(self, model: M):
|
25
25
|
"""
|
26
26
|
Constructs an unoptimized meanfield posterior approximation.
|
27
27
|
|
@@ -55,7 +55,6 @@ class MeanField(Variational):
|
|
55
55
|
|
56
56
|
return filter_spec
|
57
57
|
|
58
|
-
|
59
58
|
@eqx.filter_jit
|
60
59
|
def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
|
61
60
|
# Sample variational draws
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any, Self, Tuple
|
1
|
+
from typing import Any, Generic, Self, Tuple, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.flatten_util as jfu
|
@@ -9,8 +9,8 @@ from jaxtyping import Array, Key, Scalar
|
|
9
9
|
|
10
10
|
from bayinx.core import Flow, Model, Variational
|
11
11
|
|
12
|
-
|
13
|
-
class NormalizingFlow(Variational):
|
12
|
+
M = TypeVar('M', bound=Model)
|
13
|
+
class NormalizingFlow(Variational, Generic[M]):
|
14
14
|
"""
|
15
15
|
An ordered collection of diffeomorphisms that map a base distribution to a
|
16
16
|
normalized approximation of a posterior distribution.
|
@@ -23,7 +23,7 @@ class NormalizingFlow(Variational):
|
|
23
23
|
flows: list[Flow]
|
24
24
|
base: Variational
|
25
25
|
|
26
|
-
def __init__(self, base: Variational, flows: list[Flow], model:
|
26
|
+
def __init__(self, base: Variational, flows: list[Flow], model: M):
|
27
27
|
"""
|
28
28
|
Constructs an unoptimized normalizing flow posterior approximation.
|
29
29
|
|
bayinx/mhx/vi/standard.py
CHANGED
@@ -1,29 +1,25 @@
|
|
1
|
-
from typing import Callable
|
2
1
|
|
3
2
|
import equinox as eqx
|
4
3
|
import jax.numpy as jnp
|
5
4
|
import jax.random as jr
|
6
5
|
import jax.tree_util as jtu
|
7
6
|
from jax.flatten_util import ravel_pytree
|
8
|
-
from jaxtyping import Array,
|
7
|
+
from jaxtyping import Array, Key
|
9
8
|
|
10
|
-
from bayinx.core import
|
9
|
+
from bayinx.core._variational import M, Variational
|
11
10
|
from bayinx.dists import normal
|
12
11
|
|
13
12
|
|
14
|
-
class Standard(Variational):
|
13
|
+
class Standard(Variational[M]):
|
15
14
|
"""
|
16
15
|
A standard normal approximation to a posterior distribution.
|
17
16
|
|
18
17
|
# Attributes
|
19
18
|
- `dim`: Dimension of the parameter space.
|
20
19
|
"""
|
21
|
-
|
22
20
|
dim: int
|
23
|
-
_unflatten: Callable[[Float[Array, "..."]], Model]
|
24
|
-
_constraints: Model
|
25
21
|
|
26
|
-
def __init__(self, model:
|
22
|
+
def __init__(self, model: M):
|
27
23
|
"""
|
28
24
|
Constructs a standard normal approximation to a posterior distribution.
|
29
25
|
|
bayinx-0.3.6.dist-info/RECORD
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
|
2
|
+
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
|
4
|
+
bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
|
5
|
+
bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
|
6
|
+
bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
|
7
|
+
bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
|
+
bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
|
9
|
+
bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
|
10
|
+
bayinx/core/_variational.py,sha256=szm1WuUh_3pxzFfQy92TR4p2Sk-fR6rO-4-LrJMeVGI,5356
|
11
|
+
bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
|
12
|
+
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
|
+
bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
|
14
|
+
bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
|
15
|
+
bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
|
16
|
+
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
17
|
+
bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
|
18
|
+
bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
19
|
+
bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
|
20
|
+
bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
|
21
|
+
bayinx/dists/censored/posnormal/r.py,sha256=hyuNR3HZY-Tgtso-WwjcZT6Ejxfyax_VKwIvVix44Jc,2362
|
22
|
+
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
23
|
+
bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
|
24
|
+
bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
|
25
|
+
bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
|
26
|
+
bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
|
27
|
+
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
28
|
+
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
29
|
+
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
30
|
+
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
31
|
+
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
32
|
+
bayinx-0.3.6.dist-info/METADATA,sha256=WEdMVyISWGgK0KJvuSlkpbObsxiVfGvIxky7OsuYdXg,3079
|
33
|
+
bayinx-0.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
34
|
+
bayinx-0.3.6.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
|
35
|
+
bayinx-0.3.6.dist-info/RECORD,,
|
bayinx-0.3.6.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2025 Todd McCready
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
bayinx-0.3.4.dist-info/RECORD
DELETED
@@ -1,31 +0,0 @@
|
|
1
|
-
bayinx/__init__.py,sha256=5fb_tGeEVnrNt6IQqu7gZaJskBJHqjcg08JRPrY2ANo,139
|
2
|
-
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
|
4
|
-
bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
|
5
|
-
bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
|
6
|
-
bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
|
7
|
-
bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
|
-
bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
|
9
|
-
bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
|
10
|
-
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
|
-
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
13
|
-
bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
14
|
-
bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
|
15
|
-
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
16
|
-
bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
bayinx/dists/censored/gamma2/__init__.py,sha256=2EaQcgCXEwaRoHChVlD02ZMfgiwQAqey6uLPov1lcwE,21
|
18
|
-
bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
|
19
|
-
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
20
|
-
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
21
|
-
bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
|
22
|
-
bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
|
23
|
-
bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
|
24
|
-
bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
|
25
|
-
bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
|
26
|
-
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
27
|
-
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
28
|
-
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
29
|
-
bayinx-0.3.4.dist-info/METADATA,sha256=EpVIXPifXNloZfCCWNuNaVhWO_dMEujN3V_kVZz2Q6Y,3057
|
30
|
-
bayinx-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
31
|
-
bayinx-0.3.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|