bayinx 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from bayinx.core import Model as Model
2
- from bayinx.core import Parameter as Parameter
3
- from bayinx.core.model import constrain as constrain
1
+ from bayinx.core import Model, Parameter, constrain
2
+
3
+ __all__ = ["Model", "Parameter", "constrain"]
@@ -1 +1,3 @@
1
- from bayinx.constraints.lower import Lower as Lower
1
+ from bayinx.constraints.lower import Lower
2
+
3
+ __all__ = ['Lower']
@@ -5,8 +5,7 @@ import jax.numpy as jnp
5
5
  import jax.tree as jt
6
6
  from jaxtyping import PyTree, Scalar, ScalarLike
7
7
 
8
- from bayinx.core.constraint import Constraint
9
- from bayinx.core.parameter import Parameter
8
+ from bayinx.core import Constraint, Parameter
10
9
 
11
10
 
12
11
  class Lower(Constraint):
@@ -39,8 +38,8 @@ class Lower(Constraint):
39
38
  dyn_params, static_params = eqx.partition(x, filter_spec)
40
39
 
41
40
  # Compute density adjustment
42
- laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
43
- laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
41
+ laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
42
+ laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
44
43
 
45
44
  # Compute transformation
46
45
  dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
bayinx/core/__init__.py CHANGED
@@ -1,4 +1,7 @@
1
- from bayinx.core.flow import Flow as Flow
2
- from bayinx.core.model import Model as Model
3
- from bayinx.core.parameter import Parameter as Parameter
4
- from bayinx.core.variational import Variational as Variational
1
+ from ._constraint import Constraint
2
+ from ._flow import Flow
3
+ from ._model import Model, constrain
4
+ from ._parameter import Parameter
5
+ from ._variational import Variational
6
+
7
+ __all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
@@ -4,7 +4,7 @@ from typing import Tuple
4
4
  import equinox as eqx
5
5
  from jaxtyping import Scalar
6
6
 
7
- from bayinx.core.parameter import Parameter
7
+ from bayinx.core._parameter import Parameter
8
8
 
9
9
 
10
10
  class Constraint(eqx.Module):
@@ -7,13 +7,13 @@ import jax.numpy as jnp
7
7
  import jax.tree as jt
8
8
  from jaxtyping import Scalar
9
9
 
10
- from bayinx.core.constraint import Constraint
11
- from bayinx.core.parameter import Parameter
10
+ from ._constraint import Constraint
11
+ from ._parameter import Parameter
12
12
 
13
13
 
14
14
  def constrain(constraint: Constraint):
15
15
  """Defines constraint metadata."""
16
- return field(metadata={'constraint': constraint})
16
+ return field(metadata={"constraint": constraint})
17
17
 
18
18
 
19
19
  class Model(eqx.Module):
@@ -49,12 +49,11 @@ class Model(eqx.Module):
49
49
  filter_spec = eqx.tree_at(
50
50
  lambda model: getattr(model, f.name),
51
51
  filter_spec,
52
- replace=attr.filter_spec
52
+ replace=attr.filter_spec,
53
53
  )
54
54
 
55
55
  return filter_spec
56
56
 
57
-
58
57
  @eqx.filter_jit
59
58
  def constrain_params(self) -> Tuple[Self, Scalar]:
60
59
  """
@@ -71,18 +70,16 @@ class Model(eqx.Module):
71
70
  attr = getattr(self, f.name)
72
71
 
73
72
  # Check if constrained parameter
74
- if isinstance(attr, Parameter) and 'constraint' in f.metadata:
73
+ if isinstance(attr, Parameter) and "constraint" in f.metadata:
75
74
  param = attr
76
- constraint = f.metadata['constraint']
75
+ constraint = f.metadata["constraint"]
77
76
 
78
77
  # Apply constraint
79
78
  param, laj = constraint.constrain(param)
80
79
 
81
80
  # Update parameters for constrained model
82
81
  constrained = eqx.tree_at(
83
- lambda model: getattr(model, f.name),
84
- constrained,
85
- replace=param
82
+ lambda model: getattr(model, f.name), constrained, replace=param
86
83
  )
87
84
 
88
85
  # Adjust posterior density
@@ -90,7 +87,6 @@ class Model(eqx.Module):
90
87
 
91
88
  return constrained, target
92
89
 
93
-
94
90
  @eqx.filter_jit
95
91
  def transform_params(self) -> Tuple[Self, Scalar]:
96
92
  """
@@ -4,7 +4,7 @@ import equinox as eqx
4
4
  import jax.tree as jt
5
5
  from jaxtyping import PyTree
6
6
 
7
- T = TypeVar('T', bound=PyTree)
7
+ T = TypeVar("T", bound=PyTree)
8
8
  class Parameter(eqx.Module, Generic[T]):
9
9
  """
10
10
  A container for a parameter of a `Model`.
@@ -14,8 +14,8 @@ class Parameter(eqx.Module, Generic[T]):
14
14
  # Attributes
15
15
  - `vals`: The parameter's value(s).
16
16
  """
17
- vals: T
18
17
 
18
+ vals: T
19
19
 
20
20
  def __init__(self, values: T):
21
21
  # Insert parameter values
@@ -1,6 +1,6 @@
1
1
  from abc import abstractmethod
2
2
  from functools import partial
3
- from typing import Any, Callable, Self, Tuple
3
+ from typing import Any, Callable, Generic, Self, Tuple, TypeVar
4
4
 
5
5
  import equinox as eqx
6
6
  import jax
@@ -11,10 +11,10 @@ import optax as opx
11
11
  from jaxtyping import Array, Key, PyTree, Scalar
12
12
  from optax import GradientTransformation, OptState, Schedule
13
13
 
14
- from bayinx.core import Model
14
+ from ._model import Model
15
15
 
16
-
17
- class Variational(eqx.Module):
16
+ M = TypeVar('M', bound=Model)
17
+ class Variational(eqx.Module, Generic[M]):
18
18
  """
19
19
  An abstract base class used to define variational methods.
20
20
 
@@ -23,8 +23,8 @@ class Variational(eqx.Module):
23
23
  - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
24
24
  """
25
25
 
26
- _unflatten: Callable[[Array], Model]
27
- _constraints: Model
26
+ _unflatten: Callable[[Array], M]
27
+ _constraints: M
28
28
 
29
29
  @abstractmethod
30
30
  def filter_spec(self):
@@ -34,7 +34,7 @@ class Variational(eqx.Module):
34
34
  pass
35
35
 
36
36
  @abstractmethod
37
- def sample(self, n: int, key: Key) -> Array:
37
+ def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
38
38
  """
39
39
  Sample from the variational distribution.
40
40
  """
@@ -72,10 +72,10 @@ class Variational(eqx.Module):
72
72
  - `data`: Data used to evaluate the posterior(if needed).
73
73
  """
74
74
  # Unflatten variational draw
75
- model: Model = self._unflatten(draws)
75
+ model: M = self._unflatten(draws)
76
76
 
77
77
  # Combine with constraints
78
- model: Model = eqx.combine(model, self._constraints)
78
+ model: M = eqx.combine(model, self._constraints)
79
79
 
80
80
  # Evaluate posterior density
81
81
  return model.eval(data)
@@ -160,3 +160,22 @@ class Variational(eqx.Module):
160
160
 
161
161
  # Return optimized variational
162
162
  return eqx.combine(dyn, static)
163
+
164
@eqx.filter_jit
def posterior_predictive(
    self, func: Callable[[M], Array], n: int, key: Key = jr.PRNGKey(0)
) -> Array:
    """
    Evaluate a function of the model over draws from the variational approximation.

    # Parameters
    - `func`: Maps a reconstructed model to an array (the quantity of interest).
    - `n`: Number of variational draws.
    - `key`: PRNG key used to sample the draws.

    # Returns
    The outputs of `func`, mapped over the leading axis of the sampled draws
    (assumes `self.sample(n, key)` returns draws with a leading axis of length `n` - TODO confirm).
    """
    # Sample draws from the variational approximation
    draws: Array = self.sample(n, key)

    def evaluate(draw: Array):
        # `_unflatten` only rebuilds the dynamic (flattened) component of the model
        model: M = self._unflatten(draw)

        # Re-attach the static component, the same way this class rebuilds the
        # model before evaluating the posterior density
        model = eqx.combine(model, self._constraints)

        return func(model)

    # Vectorize over draws; the outer `filter_jit` already compiles this, so no inner jit
    return jax.vmap(evaluate)(draws)
bayinx/dists/__init__.py CHANGED
@@ -0,0 +1,3 @@
1
+ from bayinx.dists import censored, gamma2, normal, posnormal
2
+
3
+ __all__ = ['censored', "gamma2", "normal", "posnormal"]
@@ -0,0 +1,3 @@
1
+ from . import posnormal
2
+
3
+ __all__ = ["posnormal"]
@@ -1 +1,3 @@
1
- from . import r as r
1
+ from . import r
2
+
3
+ __all__ = ["r"]
@@ -10,7 +10,7 @@ def prob(
10
10
  x: Float[ArrayLike, "..."],
11
11
  mu: Float[ArrayLike, "..."],
12
12
  nu: Float[ArrayLike, "..."],
13
- censor: Float[ArrayLike, "..."]
13
+ censor: Float[ArrayLike, "..."],
14
14
  ) -> Float[Array, "..."]:
15
15
  """
16
16
  The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
@@ -19,19 +19,20 @@ def prob(
19
19
  - `x`: Value(s) at which to evaluate the PMF/PDF.
20
20
  - `mu`: The positive mean.
21
21
  - `nu`: The positive inverse dispersion.
22
+ - `censor`: The positive censor value.
22
23
 
23
24
  # Returns
24
- The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
25
+ The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
25
26
  """
26
- evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
27
+ evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
27
28
 
28
29
  # Construct boolean masks
29
- uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
30
- censored: Array = jnp.array(x == censor) # pyright: ignore
30
+ uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
31
+ censored: Array = jnp.array(x == censor) # pyright: ignore
31
32
 
32
- # Evaluate mixed probability (?) function
33
+ # Evaluate probability mass/density function
33
34
  evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
34
- evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
35
+ evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
35
36
 
36
37
  return evals
37
38
 
@@ -40,7 +41,7 @@ def logprob(
40
41
  x: Float[ArrayLike, "..."],
41
42
  mu: Float[ArrayLike, "..."],
42
43
  nu: Float[ArrayLike, "..."],
43
- censor: Float[ArrayLike, "..."]
44
+ censor: Float[ArrayLike, "..."],
44
45
  ) -> Float[Array, "..."]:
45
46
  """
46
47
  The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
@@ -49,17 +50,19 @@ def logprob(
49
50
  - `x`: Value(s) at which to evaluate the log PMF/PDF.
50
51
  - `mu`: The positive mean/location.
51
52
  - `nu`: The positive inverse dispersion.
53
+ - `censor`: The positive censor value.
52
54
 
53
55
  # Returns
54
- The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
56
+ The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
55
57
  """
56
- evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
58
+ evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
57
59
 
58
60
  # Construct boolean masks
59
- uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
60
- censored: Array = jnp.array(x == censor) # pyright: ignore
61
+ uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
62
+ censored: Array = jnp.array(x == censor) # pyright: ignore
61
63
 
64
+ # Evaluate log probability mass/density function
62
65
  evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
63
- evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
66
+ evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
64
67
 
65
68
  return evals
@@ -0,0 +1,3 @@
1
+ from . import r
2
+
3
+ __all__ = ["r"]
@@ -0,0 +1,78 @@
1
+ import jax.numpy as jnp
2
+ from jaxtyping import Array, ArrayLike, Float
3
+
4
+ from bayinx.dists import posnormal
5
+
6
+
7
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The mixed probability mass/density function (PMF/PDF) for a right-censored positive Normal distribution.

    # Parameters
    - `x`: Where to evaluate the PMF/PDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The positive scale (standard deviation of the untruncated Normal).
    - `censor`: The positive censor value.

    # Returns
    The PMF/PDF evaluated at `x`: the positive-Normal density for `0 < x < censor`,
    the tail mass `P(X > censor)` at `x == censor`, and zero elsewhere.
    The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
    """
    x, mu, sigma, censor = (jnp.asarray(v) for v in (x, mu, sigma, censor))

    # Observed (uncensored) region vs. the censoring point itself
    below_censor = jnp.logical_and(x > 0.0, x < censor)
    at_censor = x == censor

    density = jnp.where(below_censor, posnormal.prob(x, mu, sigma), 0.0)

    # A censored observation carries all the mass beyond the censor point
    return jnp.where(at_censor, posnormal.ccdf(x, mu, sigma), density)
42
+
43
+
44
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
    censor: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored positive Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log PMF/PDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The positive scale (standard deviation of the untruncated Normal).
    - `censor`: The positive censor value.

    # Returns
    The log PMF/PDF evaluated at `x`: the positive-Normal log density for `0 < x < censor`,
    the log tail mass `log P(X > censor)` at `x == censor`, and `-inf` elsewhere.
    The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
    """
    x, mu, sigma, censor = (jnp.asarray(v) for v in (x, mu, sigma, censor))

    # Observed (uncensored) region vs. the censoring point itself
    below_censor = jnp.logical_and(x > jnp.asarray(0.0), x < censor)
    at_censor = x == censor

    log_density = jnp.where(below_censor, posnormal.logprob(x, mu, sigma), -jnp.inf)

    # A censored observation carries all the mass beyond the censor point
    return jnp.where(at_censor, posnormal.logccdf(x, mu, sigma), log_density)
bayinx/dists/gamma2.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import jax.lax as lax
2
+ import jax.numpy as jnp
2
3
  from jax.scipy.special import gammaln
3
4
  from jaxtyping import Array, ArrayLike, Float
4
5
 
@@ -17,6 +18,8 @@ def prob(
17
18
  # Returns
18
19
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
19
20
  """
21
+ # Cast to Array
22
+ x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
20
23
 
21
24
  return lax.exp(logprob(x, mu, nu))
22
25
 
@@ -35,5 +38,12 @@ def logprob(
35
38
  # Returns
36
39
  The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
37
40
  """
38
-
39
- return - gammaln(nu) + nu * (lax.log(nu) - lax.log(mu)) + (nu - 1.0) * lax.log(x) - (x * nu / mu) # pyright: ignore
41
+ # Cast to Array
42
+ x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
43
+
44
+ return (
45
+ -gammaln(nu)
46
+ + nu * (lax.log(nu) - lax.log(mu))
47
+ + (nu - 1.0) * lax.log(x)
48
+ - (x * nu / mu)
49
+ ) # pyright: ignore
bayinx/dists/normal.py CHANGED
@@ -1,80 +1,138 @@
1
1
  import jax.lax as lax
2
+ import jax.numpy as jnp
3
+ import jax.scipy.special as jss
2
4
  from jaxtyping import Array, ArrayLike, Float
3
5
 
4
6
  __PI = 3.141592653589793
5
7
 
6
8
 
7
9
  def prob(
8
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
10
+ x: Float[ArrayLike, "..."],
11
+ mu: Float[ArrayLike, "..."],
12
+ sigma: Float[ArrayLike, "..."],
9
13
  ) -> Float[Array, "..."]:
10
14
  """
11
15
  The probability density function (PDF) for a Normal distribution.
12
16
 
13
17
  # Parameters
14
- - `x`: Value(s) at which to evaluate the PDF.
15
- - `mu`: The mean/location.
16
- - `sigma`: The positive standard deviation.
18
+ - `x`: Where to evaluate the PDF.
19
+ - `mu`: The mean.
20
+ - `sigma`: The standard deviation.
17
21
 
18
22
  # Returns
19
23
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
20
24
  """
25
+ # Cast to Array
26
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
21
27
 
22
- return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / ( # pyright: ignore
23
- sigma * lax.sqrt(2.0 * __PI)
24
- )
28
+ return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (sigma * lax.sqrt(2.0 * __PI))
25
29
 
26
30
 
27
31
  def logprob(
28
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
32
+ x: Float[ArrayLike, "..."],
33
+ mu: Float[ArrayLike, "..."],
34
+ sigma: Float[ArrayLike, "..."],
29
35
  ) -> Float[Array, "..."]:
30
36
  """
31
37
  The log of the probability density function (log PDF) for a Normal distribution.
32
38
 
33
39
  # Parameters
34
- - `x`: Value(s) at which to evaluate the log PDF.
35
- - `mu`: The mean/location parameter(s).
36
- - `sigma`: The non-negative standard deviation parameter(s).
40
+ - `x`: Where to evaluate the log PDF.
41
+ - `mu`: The mean.
42
+ - `sigma`: The standard deviation.
37
43
 
38
44
  # Returns
39
45
  The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
40
46
  """
47
+ # Cast to Array
48
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
41
49
 
42
- return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
43
- (x - mu) / sigma # pyright: ignore
44
- )
50
+ return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square((x - mu) / sigma)
45
51
 
46
52
 
47
53
  def uprob(
48
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
54
+ x: Float[ArrayLike, "..."],
55
+ mu: Float[ArrayLike, "..."],
56
+ sigma: Float[ArrayLike, "..."],
49
57
  ) -> Float[Array, "..."]:
50
58
  """
51
59
  The unnormalized probability density function (uPDF) for a Normal distribution.
52
60
 
53
61
  # Parameters
54
- - `x`: Value(s) at which to evaluate the uPDF.
55
- - `mu`: The mean/location parameter(s).
56
- - `sigma`: The positive standard deviation parameter(s).
62
+ - `x`: Where to evaluate the PDF.
63
+ - `mu`: The mean.
64
+ - `sigma`: The standard deviation.
57
65
 
58
66
  # Returns
59
67
  The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
60
68
  """
69
+ # Cast to Array
70
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
61
71
 
62
- return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma # pyright: ignore
72
+ return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
63
73
 
64
74
 
65
75
  def ulogprob(
66
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
76
+ x: Float[ArrayLike, "..."],
77
+ mu: Float[ArrayLike, "..."],
78
+ sigma: Float[ArrayLike, "..."],
67
79
  ) -> Float[Array, "..."]:
68
80
  """
69
81
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
70
82
 
71
83
  # Parameters
72
- - `x`: Value(s) at which to evaluate the log uPDF.
73
- - `mu`: The mean/location parameter(s).
74
- - `sigma`: The non-negative standard deviation parameter(s).
84
+ - `x`: Where to evaluate the PDF.
85
+ - `mu`: The mean.
86
+ - `sigma`: The standard deviation.
75
87
 
76
88
  # Returns
77
89
  The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
78
90
  """
91
+ # Cast to Array
92
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
79
93
 
80
- return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma) # pyright: ignore
94
+ return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
95
+
96
+
97
def cdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Standardize, then defer to the standard Normal CDF
    z = (jnp.asarray(x) - jnp.asarray(mu)) / jnp.asarray(sigma)

    return jss.ndtr(z)
106
+
107
+
108
def logcdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log of the cumulative distribution function (log CDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Standardize, then use `log_ndtr`, which stays finite deep in the lower tail
    z = (jnp.asarray(x) - jnp.asarray(mu)) / jnp.asarray(sigma)

    return jss.log_ndtr(z)
117
+
118
+
119
def ccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Reflection: 1 - Phi(z) == Phi(-z), which avoids computing `1 - cdf` directly
    z = (jnp.asarray(mu) - jnp.asarray(x)) / jnp.asarray(sigma)

    return jss.ndtr(z)
128
+
129
+
130
def logccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log of the complementary cumulative distribution function (log cCDF) for a Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF.
    - `mu`: The mean.
    - `sigma`: The standard deviation.

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    # Reflection: log(1 - Phi(z)) == log Phi(-z)
    z = (jnp.asarray(mu) - jnp.asarray(x)) / jnp.asarray(sigma)

    return jss.log_ndtr(z)
@@ -0,0 +1,260 @@
1
+ import jax.numpy as jnp
2
+ from jaxtyping import Array, ArrayLike, Float
3
+
4
+ from bayinx.dists import normal
5
+
6
+
7
def prob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The probability density function (PDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the PDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The PDF evaluated at `x` (zero for negative `x`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    # Renormalize the Normal density by the mass it places on [0, inf): Phi(mu / sigma)
    density = normal.prob(x, mu, sigma) / normal.cdf(mu / sigma, 0.0, 1.0)

    # Outside the support the density vanishes
    return jnp.where(x >= jnp.asarray(0.0), density, jnp.asarray(0.0))
37
+
38
+
39
def logprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log of the probability density function (log PDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log PDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The log PDF evaluated at `x` (`-inf` for negative `x`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    # Subtract the log of the mass the Normal places on [0, inf): log Phi(mu / sigma)
    log_density = normal.logprob(x, mu, sigma) - normal.logcdf(mu / sigma, 0.0, 1.0)

    # Outside the support the log density is -inf
    return jnp.where(x >= jnp.asarray(0.0), log_density, -jnp.inf)
69
+
70
+
71
def uprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The unnormalized probability density function (uPDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the uPDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The Normal density restricted to `x >= 0` (not divided by the truncation mass; zero for negative `x`).
    The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    kernel = normal.prob(x, mu, sigma)

    return jnp.where(x >= jnp.asarray(0.0), kernel, jnp.asarray(0.0))
97
+
98
+
99
def ulogprob(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log of the unnormalized probability density function (log uPDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log uPDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The Normal log density restricted to `x >= 0` (no truncation correction; `-inf` for negative `x`).
    The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    log_kernel = normal.logprob(x, mu, sigma)

    return jnp.where(x >= jnp.asarray(0.0), log_kernel, -jnp.inf)
125
+
126
+
127
def cdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The CDF evaluated at `x` (zero for negative `x`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    Not numerically stable for small `x`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    cdf_x = normal.cdf(x, mu, sigma)  # P(X <= x) for the untruncated Normal
    mass_below_zero = normal.cdf(-mu / sigma, 0.0, 1.0)  # P(X < 0)
    mass_above_zero = normal.cdf(mu / sigma, 0.0, 1.0)  # P(X >= 0)

    # (F(x) - F(0)) / (1 - F(0)); the subtraction loses precision when x is near 0
    truncated = (cdf_x - mass_below_zero) / mass_above_zero

    return jnp.where(x >= jnp.asarray(0.0), truncated, jnp.asarray(0.0))
161
+
162
+
163
# TODO: make numerically stable
def logcdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log of the cumulative distribution function (log CDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The log CDF evaluated at `x` (`-inf` for negative `x`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.

    # Notes
    Not numerically stable for small `x`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    log_cdf_x = normal.logcdf(x, mu, sigma)  # log P(X <= x), untruncated
    log_mass_below_zero = normal.logcdf(-mu / sigma, 0.0, 1.0)  # log P(X < 0)
    log_mass_above_zero = normal.logcdf(mu / sigma, 0.0, 1.0)  # log P(X >= 0)

    # log(F(x) - F(0)) = log F(x) + log(1 - F(0) / F(x)), then renormalize
    truncated = (
        log_cdf_x
        + jnp.log1p(-jnp.exp(log_mass_below_zero - log_cdf_x))
        - log_mass_above_zero
    )

    return jnp.where(x >= jnp.asarray(0.0), truncated, -jnp.inf)
197
+
198
+
199
def ccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The cCDF evaluated at `x` (one for negative `x`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    # Untruncated survival P(X > x), computed by reflection as Phi((mu - x) / sigma)
    survival_x = normal.cdf(-x, -mu, sigma)
    mass_above_zero = normal.cdf(mu / sigma, 0.0, 1.0)

    # Below the support, all of the mass lies above x
    return jnp.where(x >= 0.0, survival_x / mass_above_zero, jnp.asarray(1.0))
229
+
230
+
231
def logccdf(
    x: Float[ArrayLike, "..."],
    mu: Float[ArrayLike, "..."],
    sigma: Float[ArrayLike, "..."],
) -> Float[Array, "..."]:
    """
    The log of the complementary cumulative distribution function (log cCDF) for a positive (zero-truncated) Normal distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF.
    - `mu`: The location (mean of the untruncated Normal).
    - `sigma`: The scale (standard deviation of the untruncated Normal).

    # Returns
    The log cCDF evaluated at `x` (zero for negative `x`). The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
    """
    x = jnp.asarray(x)
    mu = jnp.asarray(mu)
    sigma = jnp.asarray(sigma)

    # log of the untruncated survival P(X > x), via reflection
    log_survival_x = normal.logcdf(-x, -mu, sigma)
    log_mass_above_zero = normal.logcdf(mu / sigma, 0.0, 1.0)

    # Below the support, log P(X > x) = log 1 = 0
    return jnp.where(x >= 0.0, log_survival_x - log_mass_above_zero, jnp.asarray(0.0))
bayinx/mhx/vi/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
- from bayinx.mhx.vi.meanfield import MeanField as MeanField
2
- from bayinx.mhx.vi.normalizing_flow import NormalizingFlow as NormalizingFlow
3
- from bayinx.mhx.vi.standard import Standard as Standard
1
+ from bayinx.mhx.vi.meanfield import MeanField
2
+ from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
3
+ from bayinx.mhx.vi.standard import Standard
4
+
5
+ __all__ = ['MeanField', 'NormalizingFlow', 'Standard']
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Self
1
+ from typing import Any, Dict, Generic, Self, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
@@ -10,8 +10,8 @@ from jaxtyping import Array, Float, Key, Scalar
10
10
  from bayinx.core import Model, Variational
11
11
  from bayinx.dists import normal
12
12
 
13
-
14
- class MeanField(Variational):
13
+ M = TypeVar('M', bound=Model)
14
+ class MeanField(Variational, Generic[M]):
15
15
  """
16
16
  A fully factorized Gaussian approximation to a posterior distribution.
17
17
 
@@ -19,9 +19,9 @@ class MeanField(Variational):
19
19
  - `var_params`: The variational parameters for the approximation.
20
20
  """
21
21
 
22
- var_params: Dict[str, Float[Array, "..."]]
22
+ var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
23
23
 
24
- def __init__(self, model: Model):
24
+ def __init__(self, model: M):
25
25
  """
26
26
  Constructs an unoptimized meanfield posterior approximation.
27
27
 
@@ -55,7 +55,6 @@ class MeanField(Variational):
55
55
 
56
56
  return filter_spec
57
57
 
58
-
59
58
  @eqx.filter_jit
60
59
  def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
61
60
  # Sample variational draws
@@ -1,4 +1,4 @@
1
- from typing import Any, Self, Tuple
1
+ from typing import Any, Generic, Self, Tuple, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.flatten_util as jfu
@@ -9,8 +9,8 @@ from jaxtyping import Array, Key, Scalar
9
9
 
10
10
  from bayinx.core import Flow, Model, Variational
11
11
 
12
-
13
- class NormalizingFlow(Variational):
12
+ M = TypeVar('M', bound=Model)
13
+ class NormalizingFlow(Variational, Generic[M]):
14
14
  """
15
15
  An ordered collection of diffeomorphisms that map a base distribution to a
16
16
  normalized approximation of a posterior distribution.
@@ -23,7 +23,7 @@ class NormalizingFlow(Variational):
23
23
  flows: list[Flow]
24
24
  base: Variational
25
25
 
26
- def __init__(self, base: Variational, flows: list[Flow], model: Model):
26
+ def __init__(self, base: Variational, flows: list[Flow], model: M):
27
27
  """
28
28
  Constructs an unoptimized normalizing flow posterior approximation.
29
29
 
bayinx/mhx/vi/standard.py CHANGED
@@ -1,29 +1,25 @@
1
- from typing import Callable
2
1
 
3
2
  import equinox as eqx
4
3
  import jax.numpy as jnp
5
4
  import jax.random as jr
6
5
  import jax.tree_util as jtu
7
6
  from jax.flatten_util import ravel_pytree
8
- from jaxtyping import Array, Float, Key
7
+ from jaxtyping import Array, Key
9
8
 
10
- from bayinx.core import Model, Variational
9
+ from bayinx.core._variational import M, Variational
11
10
  from bayinx.dists import normal
12
11
 
13
12
 
14
- class Standard(Variational):
13
+ class Standard(Variational[M]):
15
14
  """
16
15
  A standard normal approximation to a posterior distribution.
17
16
 
18
17
  # Attributes
19
18
  - `dim`: Dimension of the parameter space.
20
19
  """
21
-
22
20
  dim: int
23
- _unflatten: Callable[[Float[Array, "..."]], Model]
24
- _constraints: Model
25
21
 
26
- def __init__(self, model: Model):
22
+ def __init__(self, model: M):
27
23
  """
28
24
  Constructs a standard normal approximation to a posterior distribution.
29
25
 
@@ -1,7 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.4
3
+ Version: 0.3.6
4
4
  Summary: Bayesian Inference with JAX
5
+ License-File: LICENSE
5
6
  Requires-Python: >=3.12
6
7
  Requires-Dist: equinox>=0.11.12
7
8
  Requires-Dist: jax>=0.4.38
@@ -0,0 +1,35 @@
1
+ bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
2
+ bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
4
+ bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
5
+ bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
6
+ bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
7
+ bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
+ bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
9
+ bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
10
+ bayinx/core/_variational.py,sha256=szm1WuUh_3pxzFfQy92TR4p2Sk-fR6rO-4-LrJMeVGI,5356
11
+ bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
12
+ bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
+ bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
14
+ bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
15
+ bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
16
+ bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
17
+ bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
18
+ bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
19
+ bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
20
+ bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
21
+ bayinx/dists/censored/posnormal/r.py,sha256=hyuNR3HZY-Tgtso-WwjcZT6Ejxfyax_VKwIvVix44Jc,2362
22
+ bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
23
+ bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
24
+ bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
25
+ bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
26
+ bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
27
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
28
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
29
+ bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
30
+ bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
31
+ bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
32
+ bayinx-0.3.6.dist-info/METADATA,sha256=WEdMVyISWGgK0KJvuSlkpbObsxiVfGvIxky7OsuYdXg,3079
33
+ bayinx-0.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
+ bayinx-0.3.6.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
35
+ bayinx-0.3.6.dist-info/RECORD,,
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Todd McCready
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -1,31 +0,0 @@
1
- bayinx/__init__.py,sha256=5fb_tGeEVnrNt6IQqu7gZaJskBJHqjcg08JRPrY2ANo,139
2
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
4
- bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
5
- bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
6
- bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
7
- bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
- bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
9
- bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
10
- bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
11
- bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
- bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
14
- bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
15
- bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
16
- bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- bayinx/dists/censored/gamma2/__init__.py,sha256=2EaQcgCXEwaRoHChVlD02ZMfgiwQAqey6uLPov1lcwE,21
18
- bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
19
- bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
20
- bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
21
- bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
22
- bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
23
- bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
24
- bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
25
- bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
26
- bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
27
- bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
28
- bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
29
- bayinx-0.3.4.dist-info/METADATA,sha256=EpVIXPifXNloZfCCWNuNaVhWO_dMEujN3V_kVZz2Q6Y,3057
30
- bayinx-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
31
- bayinx-0.3.4.dist-info/RECORD,,
File without changes
File without changes