bayinx 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from bayinx.core import Model as Model
2
- from bayinx.core import Parameter as Parameter
3
- from bayinx.core.model import constrain as constrain
1
+ from bayinx.core import Model, Parameter, constrain
2
+
3
+ __all__ = ["Model", "Parameter", "constrain"]
@@ -1 +1,3 @@
1
- from bayinx.constraints.lower import Lower as Lower
1
+ from bayinx.constraints.lower import Lower
2
+
3
+ __all__ = ['Lower']
@@ -5,8 +5,7 @@ import jax.numpy as jnp
5
5
  import jax.tree as jt
6
6
  from jaxtyping import PyTree, Scalar, ScalarLike
7
7
 
8
- from bayinx.core.constraint import Constraint
9
- from bayinx.core.parameter import Parameter
8
+ from bayinx.core import Constraint, Parameter
10
9
 
11
10
 
12
11
  class Lower(Constraint):
@@ -39,8 +38,8 @@ class Lower(Constraint):
39
38
  dyn_params, static_params = eqx.partition(x, filter_spec)
40
39
 
41
40
  # Compute density adjustment
42
- laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
43
- laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
41
+ laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
42
+ laj: Scalar = jt.reduce(lambda a, b: a + b, laj)
44
43
 
45
44
  # Compute transformation
46
45
  dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
bayinx/core/__init__.py CHANGED
@@ -1,4 +1,7 @@
1
- from bayinx.core.flow import Flow as Flow
2
- from bayinx.core.model import Model as Model
3
- from bayinx.core.parameter import Parameter as Parameter
4
- from bayinx.core.variational import Variational as Variational
1
+ from ._constraint import Constraint
2
+ from ._flow import Flow
3
+ from ._model import Model, constrain
4
+ from ._parameter import Parameter
5
+ from ._variational import Variational
6
+
7
+ __all__ = ["Constraint", "Flow", "Model", "constrain", "Parameter", "Variational"]
@@ -4,7 +4,7 @@ from typing import Tuple
4
4
  import equinox as eqx
5
5
  from jaxtyping import Scalar
6
6
 
7
- from bayinx.core.parameter import Parameter
7
+ from bayinx.core._parameter import Parameter
8
8
 
9
9
 
10
10
  class Constraint(eqx.Module):
@@ -7,13 +7,13 @@ import jax.numpy as jnp
7
7
  import jax.tree as jt
8
8
  from jaxtyping import Scalar
9
9
 
10
- from bayinx.core.constraint import Constraint
11
- from bayinx.core.parameter import Parameter
10
+ from ._constraint import Constraint
11
+ from ._parameter import Parameter
12
12
 
13
13
 
14
14
  def constrain(constraint: Constraint):
15
15
  """Defines constraint metadata."""
16
- return field(metadata={'constraint': constraint})
16
+ return field(metadata={"constraint": constraint})
17
17
 
18
18
 
19
19
  class Model(eqx.Module):
@@ -49,12 +49,11 @@ class Model(eqx.Module):
49
49
  filter_spec = eqx.tree_at(
50
50
  lambda model: getattr(model, f.name),
51
51
  filter_spec,
52
- replace=attr.filter_spec
52
+ replace=attr.filter_spec,
53
53
  )
54
54
 
55
55
  return filter_spec
56
56
 
57
-
58
57
  @eqx.filter_jit
59
58
  def constrain_params(self) -> Tuple[Self, Scalar]:
60
59
  """
@@ -71,18 +70,16 @@ class Model(eqx.Module):
71
70
  attr = getattr(self, f.name)
72
71
 
73
72
  # Check if constrained parameter
74
- if isinstance(attr, Parameter) and 'constraint' in f.metadata:
73
+ if isinstance(attr, Parameter) and "constraint" in f.metadata:
75
74
  param = attr
76
- constraint = f.metadata['constraint']
75
+ constraint = f.metadata["constraint"]
77
76
 
78
77
  # Apply constraint
79
78
  param, laj = constraint.constrain(param)
80
79
 
81
80
  # Update parameters for constrained model
82
81
  constrained = eqx.tree_at(
83
- lambda model: getattr(model, f.name),
84
- constrained,
85
- replace=param
82
+ lambda model: getattr(model, f.name), constrained, replace=param
86
83
  )
87
84
 
88
85
  # Adjust posterior density
@@ -90,7 +87,6 @@ class Model(eqx.Module):
90
87
 
91
88
  return constrained, target
92
89
 
93
-
94
90
  @eqx.filter_jit
95
91
  def transform_params(self) -> Tuple[Self, Scalar]:
96
92
  """
@@ -4,7 +4,7 @@ import equinox as eqx
4
4
  import jax.tree as jt
5
5
  from jaxtyping import PyTree
6
6
 
7
- T = TypeVar('T', bound=PyTree)
7
+ T = TypeVar("T", bound=PyTree)
8
8
  class Parameter(eqx.Module, Generic[T]):
9
9
  """
10
10
  A container for a parameter of a `Model`.
@@ -14,8 +14,8 @@ class Parameter(eqx.Module, Generic[T]):
14
14
  # Attributes
15
15
  - `vals`: The parameter's value(s).
16
16
  """
17
- vals: T
18
17
 
18
+ vals: T
19
19
 
20
20
  def __init__(self, values: T):
21
21
  # Insert parameter values
@@ -1,6 +1,6 @@
1
1
  from abc import abstractmethod
2
2
  from functools import partial
3
- from typing import Any, Callable, Self, Tuple
3
+ from typing import Any, Callable, Generic, Self, Tuple, TypeVar
4
4
 
5
5
  import equinox as eqx
6
6
  import jax
@@ -11,10 +11,10 @@ import optax as opx
11
11
  from jaxtyping import Array, Key, PyTree, Scalar
12
12
  from optax import GradientTransformation, OptState, Schedule
13
13
 
14
- from bayinx.core import Model
14
+ from ._model import Model
15
15
 
16
-
17
- class Variational(eqx.Module):
16
+ M = TypeVar('M', bound=Model)
17
+ class Variational(eqx.Module, Generic[M]):
18
18
  """
19
19
  An abstract base class used to define variational methods.
20
20
 
@@ -23,8 +23,8 @@ class Variational(eqx.Module):
23
23
  - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
24
24
  """
25
25
 
26
- _unflatten: Callable[[Array], Model]
27
- _constraints: Model
26
+ _unflatten: Callable[[Array], M]
27
+ _constraints: M
28
28
 
29
29
  @abstractmethod
30
30
  def filter_spec(self):
@@ -34,7 +34,7 @@ class Variational(eqx.Module):
34
34
  pass
35
35
 
36
36
  @abstractmethod
37
- def sample(self, n: int, key: Key) -> Array:
37
+ def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
38
38
  """
39
39
  Sample from the variational distribution.
40
40
  """
@@ -72,10 +72,10 @@ class Variational(eqx.Module):
72
72
  - `data`: Data used to evaluate the posterior(if needed).
73
73
  """
74
74
  # Unflatten variational draw
75
- model: Model = self._unflatten(draws)
75
+ model: M = self._unflatten(draws)
76
76
 
77
77
  # Combine with constraints
78
- model: Model = eqx.combine(model, self._constraints)
78
+ model: M = eqx.combine(model, self._constraints)
79
79
 
80
80
  # Evaluate posterior density
81
81
  return model.eval(data)
@@ -106,8 +106,8 @@ class Variational(eqx.Module):
106
106
  dyn, static = eqx.partition(self, self.filter_spec)
107
107
 
108
108
  # Construct scheduler
109
- schedule: Schedule = opx.cosine_decay_schedule(
110
- init_value=learning_rate, decay_steps=max_iters
109
+ schedule: Schedule = opx.warmup_cosine_decay_schedule(
110
+ init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
111
111
  )
112
112
 
113
113
  # Initialize optimizer
@@ -160,3 +160,22 @@ class Variational(eqx.Module):
160
160
 
161
161
  # Return optimized variational
162
162
  return eqx.combine(dyn, static)
163
+
164
+ @eqx.filter_jit
165
+ def posterior_predictive(
166
+ self, func: Callable[[M], Array], n: int, key: Key = jr.PRNGKey(0)
167
+ ) -> Array:
168
+ # Sample draws from the variational approximation
169
+ draws: Array = self.sample(n, key)
170
+
171
+ # Evaluate posterior predictive
172
+ @jax.jit
173
+ @jax.vmap
174
+ def evaluate(draw: Array):
175
+ # Reconstruct model
176
+ model: M = self._unflatten(draw)
177
+
178
+ # Evaluate
179
+ return func(model)
180
+
181
+ return evaluate(draws)
bayinx/dists/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from bayinx.dists import normal, posnormal
1
+ from bayinx.dists import censored, gamma2, normal, posnormal
2
2
 
3
- __all__ = ['normal', 'posnormal']
3
+ __all__ = ['censored', "gamma2", "normal", "posnormal"]
@@ -1,3 +1,3 @@
1
- from . import gamma2, posnormal
1
+ from . import posnormal
2
2
 
3
- __all__ = ['gamma2', 'posnormal']
3
+ __all__ = ["posnormal"]
@@ -1,3 +1,3 @@
1
1
  from . import r
2
2
 
3
- __all__ = ['r']
3
+ __all__ = ["r"]
@@ -10,7 +10,7 @@ def prob(
10
10
  x: Float[ArrayLike, "..."],
11
11
  mu: Float[ArrayLike, "..."],
12
12
  nu: Float[ArrayLike, "..."],
13
- censor: Float[ArrayLike, "..."]
13
+ censor: Float[ArrayLike, "..."],
14
14
  ) -> Float[Array, "..."]:
15
15
  """
16
16
  The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
@@ -24,15 +24,15 @@ def prob(
24
24
  # Returns
25
25
  The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
26
26
  """
27
- evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
27
+ evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
28
28
 
29
29
  # Construct boolean masks
30
- uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
31
- censored: Array = jnp.array(x == censor) # pyright: ignore
30
+ uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
31
+ censored: Array = jnp.array(x == censor) # pyright: ignore
32
32
 
33
33
  # Evaluate probability mass/density function
34
34
  evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
35
- evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
35
+ evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
36
36
 
37
37
  return evals
38
38
 
@@ -41,7 +41,7 @@ def logprob(
41
41
  x: Float[ArrayLike, "..."],
42
42
  mu: Float[ArrayLike, "..."],
43
43
  nu: Float[ArrayLike, "..."],
44
- censor: Float[ArrayLike, "..."]
44
+ censor: Float[ArrayLike, "..."],
45
45
  ) -> Float[Array, "..."]:
46
46
  """
47
47
  The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
@@ -55,14 +55,14 @@ def logprob(
55
55
  # Returns
56
56
  The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
57
57
  """
58
- evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
58
+ evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
59
59
 
60
60
  # Construct boolean masks
61
- uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
62
- censored: Array = jnp.array(x == censor) # pyright: ignore
61
+ uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
62
+ censored: Array = jnp.array(x == censor) # pyright: ignore
63
63
 
64
64
  # Evaluate log probability mass/density function
65
65
  evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
66
- evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
66
+ evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
67
67
 
68
68
  return evals
@@ -1,3 +1,3 @@
1
1
  from . import r
2
2
 
3
- __all__ = ['r']
3
+ __all__ = ["r"]
@@ -8,10 +8,10 @@ def prob(
8
8
  x: Float[ArrayLike, "..."],
9
9
  mu: Float[ArrayLike, "..."],
10
10
  sigma: Float[ArrayLike, "..."],
11
- censor: Float[ArrayLike, "..."]
11
+ censor: Float[ArrayLike, "..."],
12
12
  ) -> Float[Array, "..."]:
13
13
  """
14
- The mixed probability mass/density function (PMF/PDF) for a censored positive Normal distribution.
14
+ The mixed probability mass/density function (PMF/PDF) for a right-censored positive Normal distribution.
15
15
 
16
16
  # Parameters
17
17
  - `x`: Value(s) at which to evaluate the PMF/PDF.
@@ -23,7 +23,12 @@ def prob(
23
23
  The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
24
24
  """
25
25
  # Cast to Array
26
- x, mu, sigma, censor = jnp.array(x), jnp.array(mu), jnp.array(sigma), jnp.array(censor)
26
+ x, mu, sigma, censor = (
27
+ jnp.asarray(x),
28
+ jnp.asarray(mu),
29
+ jnp.asarray(sigma),
30
+ jnp.asarray(censor),
31
+ )
27
32
 
28
33
  # Construct boolean masks
29
34
  uncensored: Array = jnp.logical_and(0.0 < x, x < censor)
@@ -31,7 +36,7 @@ def prob(
31
36
 
32
37
  # Evaluate probability mass/density function
33
38
  evals = jnp.where(uncensored, posnormal.prob(x, mu, sigma), 0.0)
34
- evals = jnp.where(censored, posnormal.ccdf(x,mu,sigma), evals)
39
+ evals = jnp.where(censored, posnormal.ccdf(x, mu, sigma), evals)
35
40
 
36
41
  return evals
37
42
 
@@ -40,10 +45,10 @@ def logprob(
40
45
  x: Float[ArrayLike, "..."],
41
46
  mu: Float[ArrayLike, "..."],
42
47
  sigma: Float[ArrayLike, "..."],
43
- censor: Float[ArrayLike, "..."]
48
+ censor: Float[ArrayLike, "..."],
44
49
  ) -> Float[Array, "..."]:
45
50
  """
46
- The log-transformed mixed probability mass/density function (log PMF/PDF) for a censored positive Normal distribution.
51
+ The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored positive Normal distribution.
47
52
 
48
53
  # Parameters
49
54
  - `x`: Where to evaluate the log PMF/PDF.
@@ -55,10 +60,15 @@ def logprob(
55
60
  The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
56
61
  """
57
62
  # Cast to Array
58
- x, mu, sigma, censor = jnp.array(x), jnp.array(mu), jnp.array(sigma), jnp.array(censor)
63
+ x, mu, sigma, censor = (
64
+ jnp.asarray(x),
65
+ jnp.asarray(mu),
66
+ jnp.asarray(sigma),
67
+ jnp.asarray(censor),
68
+ )
59
69
 
60
- # Construct boolean masks
61
- uncensored: Array = jnp.logical_and(jnp.array(0.0) < x, x < censor)
70
+ # Construct boolean masks for censoring
71
+ uncensored: Array = jnp.logical_and(jnp.asarray(0.0) < x, x < censor)
62
72
  censored: Array = x == censor
63
73
 
64
74
  # Evaluate log probability mass/density function
bayinx/dists/gamma2.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import jax.lax as lax
2
+ import jax.numpy as jnp
2
3
  from jax.scipy.special import gammaln
3
4
  from jaxtyping import Array, ArrayLike, Float
4
5
 
@@ -17,6 +18,8 @@ def prob(
17
18
  # Returns
18
19
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
19
20
  """
21
+ # Cast to Array
22
+ x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
20
23
 
21
24
  return lax.exp(logprob(x, mu, nu))
22
25
 
@@ -35,5 +38,12 @@ def logprob(
35
38
  # Returns
36
39
  The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
37
40
  """
38
-
39
- return - gammaln(nu) + nu * (lax.log(nu) - lax.log(mu)) + (nu - 1.0) * lax.log(x) - (x * nu / mu) # pyright: ignore
41
+ # Cast to Array
42
+ x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
43
+
44
+ return (
45
+ -gammaln(nu)
46
+ + nu * (lax.log(nu) - lax.log(mu))
47
+ + (nu - 1.0) * lax.log(x)
48
+ - (x * nu / mu)
49
+ ) # pyright: ignore
bayinx/dists/normal.py CHANGED
@@ -7,116 +7,132 @@ __PI = 3.141592653589793
7
7
 
8
8
 
9
9
  def prob(
10
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
10
+ x: Float[ArrayLike, "..."],
11
+ mu: Float[ArrayLike, "..."],
12
+ sigma: Float[ArrayLike, "..."],
11
13
  ) -> Float[Array, "..."]:
12
14
  """
13
15
  The probability density function (PDF) for a Normal distribution.
14
16
 
15
17
  # Parameters
16
- - `x`: Value(s) at which to evaluate the PDF.
17
- - `mu`: The mean/location.
18
- - `sigma`: The positive standard deviation.
18
+ - `x`: Where to evaluate the PDF.
19
+ - `mu`: The mean.
20
+ - `sigma`: The standard deviation.
19
21
 
20
22
  # Returns
21
23
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
22
24
  """
23
25
  # Cast to Array
24
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
26
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
25
27
 
26
- return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (
27
- sigma * lax.sqrt(2.0 * __PI)
28
- )
28
+ return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (sigma * lax.sqrt(2.0 * __PI))
29
29
 
30
30
 
31
31
  def logprob(
32
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
32
+ x: Float[ArrayLike, "..."],
33
+ mu: Float[ArrayLike, "..."],
34
+ sigma: Float[ArrayLike, "..."],
33
35
  ) -> Float[Array, "..."]:
34
36
  """
35
37
  The log of the probability density function (log PDF) for a Normal distribution.
36
38
 
37
39
  # Parameters
38
- - `x`: Value(s) at which to evaluate the log PDF.
39
- - `mu`: The mean/location parameter(s).
40
- - `sigma`: The non-negative standard deviation parameter(s).
40
+ - `x`: Where to evaluate the log PDF.
41
+ - `mu`: The mean.
42
+ - `sigma`: The standard deviation.
41
43
 
42
44
  # Returns
43
45
  The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
44
46
  """
45
47
  # Cast to Array
46
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
48
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
47
49
 
48
- return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square(
49
- (x - mu) / sigma
50
- )
50
+ return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square((x - mu) / sigma)
51
51
 
52
52
 
53
53
  def uprob(
54
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
54
+ x: Float[ArrayLike, "..."],
55
+ mu: Float[ArrayLike, "..."],
56
+ sigma: Float[ArrayLike, "..."],
55
57
  ) -> Float[Array, "..."]:
56
58
  """
57
59
  The unnormalized probability density function (uPDF) for a Normal distribution.
58
60
 
59
61
  # Parameters
60
- - `x`: Value(s) at which to evaluate the uPDF.
61
- - `mu`: The mean/location parameter(s).
62
- - `sigma`: The positive standard deviation parameter(s).
62
+ - `x`: Where to evaluate the PDF.
63
+ - `mu`: The mean.
64
+ - `sigma`: The standard deviation.
63
65
 
64
66
  # Returns
65
67
  The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
66
68
  """
67
69
  # Cast to Array
68
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
70
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
69
71
 
70
72
  return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
71
73
 
72
74
 
73
75
  def ulogprob(
74
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
76
+ x: Float[ArrayLike, "..."],
77
+ mu: Float[ArrayLike, "..."],
78
+ sigma: Float[ArrayLike, "..."],
75
79
  ) -> Float[Array, "..."]:
76
80
  """
77
81
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
78
82
 
79
83
  # Parameters
80
- - `x`: Value(s) at which to evaluate the log uPDF.
81
- - `mu`: The mean/location parameter(s).
82
- - `sigma`: The non-negative standard deviation parameter(s).
84
+ - `x`: Where to evaluate the PDF.
85
+ - `mu`: The mean.
86
+ - `sigma`: The standard deviation.
83
87
 
84
88
  # Returns
85
89
  The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
86
90
  """
87
91
  # Cast to Array
88
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
92
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
89
93
 
90
94
  return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
91
95
 
96
+
92
97
  def cdf(
93
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
98
+ x: Float[ArrayLike, "..."],
99
+ mu: Float[ArrayLike, "..."],
100
+ sigma: Float[ArrayLike, "..."],
94
101
  ) -> Float[Array, "..."]:
95
102
  # Cast to Array
96
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
103
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
97
104
 
98
105
  return jss.ndtr((x - mu) / sigma)
99
106
 
107
+
100
108
  def logcdf(
101
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
109
+ x: Float[ArrayLike, "..."],
110
+ mu: Float[ArrayLike, "..."],
111
+ sigma: Float[ArrayLike, "..."],
102
112
  ) -> Float[Array, "..."]:
103
113
  # Cast to Array
104
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
114
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
105
115
 
106
116
  return jss.log_ndtr((x - mu) / sigma)
107
117
 
118
+
108
119
  def ccdf(
109
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
120
+ x: Float[ArrayLike, "..."],
121
+ mu: Float[ArrayLike, "..."],
122
+ sigma: Float[ArrayLike, "..."],
110
123
  ) -> Float[Array, "..."]:
111
124
  # Cast to Array
112
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
125
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
113
126
 
114
127
  return jss.ndtr((mu - x) / sigma)
115
128
 
129
+
116
130
  def logccdf(
117
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
131
+ x: Float[ArrayLike, "..."],
132
+ mu: Float[ArrayLike, "..."],
133
+ sigma: Float[ArrayLike, "..."],
118
134
  ) -> Float[Array, "..."]:
119
135
  # Cast to Array
120
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
136
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
121
137
 
122
138
  return jss.log_ndtr((mu - x) / sigma)
bayinx/dists/posnormal.py CHANGED
@@ -5,13 +5,15 @@ from bayinx.dists import normal
5
5
 
6
6
 
7
7
  def prob(
8
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
8
+ x: Float[ArrayLike, "..."],
9
+ mu: Float[ArrayLike, "..."],
10
+ sigma: Float[ArrayLike, "..."],
9
11
  ) -> Float[Array, "..."]:
10
12
  """
11
13
  The probability density function (PDF) for a positive Normal distribution.
12
14
 
13
15
  # Parameters
14
- - `x`: Value(s) at which to evaluate the PDF.
16
+ - `x`: Where to evaluate the PDF.
15
17
  - `mu`: The mean.
16
18
  - `sigma`: The standard deviation.
17
19
 
@@ -19,28 +21,31 @@ def prob(
19
21
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
20
22
  """
21
23
  # Cast to Array
22
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
24
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
23
25
 
24
26
  # Construct boolean mask for non-negative elements
25
- non_negative: Array = jnp.array(0.0) <= x
27
+ non_negative: Array = jnp.asarray(0.0) <= x
26
28
 
27
29
  # Evaluate PDF
28
30
  evals = jnp.where(
29
31
  non_negative,
30
- normal.prob(x, mu, sigma) / normal.cdf(mu/sigma, 0.0, 1.0),
31
- jnp.array(0.0))
32
+ normal.prob(x, mu, sigma) / normal.cdf(mu / sigma, 0.0, 1.0),
33
+ jnp.asarray(0.0),
34
+ )
32
35
 
33
36
  return evals
34
37
 
35
38
 
36
39
  def logprob(
37
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
40
+ x: Float[ArrayLike, "..."],
41
+ mu: Float[ArrayLike, "..."],
42
+ sigma: Float[ArrayLike, "..."],
38
43
  ) -> Float[Array, "..."]:
39
44
  """
40
45
  The log of the probability density function (log PDF) for a positive Normal distribution.
41
46
 
42
47
  # Parameters
43
- - `x`: Value(s) at which to evaluate the log PDF.
48
+ - `x`: Where to evaluate the log PDF.
44
49
  - `mu`: The mean.
45
50
  - `sigma`: The standard deviation.
46
51
 
@@ -48,88 +53,89 @@ def logprob(
48
53
  The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
49
54
  """
50
55
  # Cast to Array
51
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
56
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
52
57
 
53
58
  # Construct boolean mask for non-negative elements
54
- non_negative: Array = jnp.array(0.0) <= x
59
+ non_negative: Array = jnp.asarray(0.0) <= x
55
60
 
56
61
  # Evaluate log PDF
57
62
  evals = jnp.where(
58
63
  non_negative,
59
- normal.logprob(x, mu, sigma) - normal.logcdf(mu/sigma, 0.0, 1.0),
60
- -jnp.inf)
64
+ normal.logprob(x, mu, sigma) - normal.logcdf(mu / sigma, 0.0, 1.0),
65
+ -jnp.inf,
66
+ )
61
67
 
62
68
  return evals
63
69
 
64
70
 
65
71
  def uprob(
66
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
72
+ x: Float[ArrayLike, "..."],
73
+ mu: Float[ArrayLike, "..."],
74
+ sigma: Float[ArrayLike, "..."],
67
75
  ) -> Float[Array, "..."]:
68
76
  """
69
77
  The unnormalized probability density function (uPDF) for a positive Normal distribution.
70
78
 
71
79
  # Parameters
72
- - `x`: Value(s) at which to evaluate the uPDF.
73
- - `mu`: The mean/location parameter(s).
74
- - `sigma`: The positive standard deviation parameter(s).
80
+ - `x`: Where to evaluate the uPDF.
81
+ - `mu`: The mean.
82
+ - `sigma`: The standard deviation.
75
83
 
76
84
  # Returns
77
85
  The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
78
86
  """
79
87
  # Cast to Array
80
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
88
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
81
89
 
82
90
  # Construct boolean mask for non-negative elements
83
- non_negative: Array = jnp.array(0.0) <= x
91
+ non_negative: Array = jnp.asarray(0.0) <= x
84
92
 
85
93
  # Evaluate PDF
86
- evals = jnp.where(
87
- non_negative,
88
- normal.prob(x, mu, sigma),
89
- jnp.array(0.0))
94
+ evals = jnp.where(non_negative, normal.prob(x, mu, sigma), jnp.asarray(0.0))
90
95
 
91
96
  return evals
92
97
 
93
98
 
94
99
  def ulogprob(
95
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
100
+ x: Float[ArrayLike, "..."],
101
+ mu: Float[ArrayLike, "..."],
102
+ sigma: Float[ArrayLike, "..."],
96
103
  ) -> Float[Array, "..."]:
97
104
  """
98
105
  The log of the unnormalized probability density function (log uPDF) for a positive Normal distribution.
99
106
 
100
107
  # Parameters
101
- - `x`: Value(s) at which to evaluate the log uPDF.
102
- - `mu`: The mean/location parameter(s).
103
- - `sigma`: The non-negative standard deviation parameter(s).
108
+ - `x`: Where to evaluate the log uPDF.
109
+ - `mu`: The mean.
110
+ - `sigma`: The standard deviation.
104
111
 
105
112
  # Returns
106
113
  The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
107
114
  """
108
115
  # Cast to Array
109
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
116
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
110
117
 
111
118
  # Construct boolean mask for non-negative elements
112
- non_negative: Array = jnp.array(0.0) <= x
119
+ non_negative: Array = jnp.asarray(0.0) <= x
113
120
 
114
121
  # Evaluate log PDF
115
- evals = jnp.where(
116
- non_negative,
117
- normal.logprob(x, mu, sigma),
118
- -jnp.inf)
122
+ evals = jnp.where(non_negative, normal.logprob(x, mu, sigma), -jnp.inf)
119
123
 
120
124
  return evals
121
125
 
122
126
 
123
127
  def cdf(
124
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
128
+ x: Float[ArrayLike, "..."],
129
+ mu: Float[ArrayLike, "..."],
130
+ sigma: Float[ArrayLike, "..."],
125
131
  ) -> Float[Array, "..."]:
126
132
  """
127
133
  The cumulative density function (CDF) for a positive Normal distribution.
128
134
 
129
135
  # Parameters
130
- - `x`: Value(s) at which to evaluate the log uPDF.
131
- - `mu`: The mean/location parameter(s).
132
- - `sigma`: The non-negative standard deviation parameter(s).
136
+ - `x`: Where to evaluate the CDF.
137
+ - `mu`: The mean.
138
+ - `sigma`: The standard deviation.
133
139
 
134
140
  # Returns
135
141
  The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
@@ -138,35 +144,35 @@ def cdf(
138
144
  Not numerically stable for small `x`.
139
145
  """
140
146
  # Cast to Array
141
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
147
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
142
148
 
143
149
  # Construct boolean mask for non-negative elements
144
- non_negative: Array = jnp.array(0.0) <= x
150
+ non_negative: Array = jnp.asarray(0.0) <= x
145
151
 
146
152
  # Compute intermediates
147
153
  A: Array = normal.cdf(x, mu, sigma)
148
- B: Array = normal.cdf(- mu / sigma, 0.0, 1.0)
154
+ B: Array = normal.cdf(-mu / sigma, 0.0, 1.0)
149
155
  C: Array = normal.cdf(mu / sigma, 0.0, 1.0)
150
156
 
151
157
  # Evaluate CDF
152
- evals = jnp.where(
153
- non_negative,
154
- (A - B) / C,
155
- jnp.array(0.0))
158
+ evals = jnp.where(non_negative, (A - B) / C, jnp.asarray(0.0))
156
159
 
157
160
  return evals
158
161
 
162
+
159
163
  # TODO: make numerically stable
160
164
  def logcdf(
161
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
165
+ x: Float[ArrayLike, "..."],
166
+ mu: Float[ArrayLike, "..."],
167
+ sigma: Float[ArrayLike, "..."],
162
168
  ) -> Float[Array, "..."]:
163
169
  """
164
170
  The log-transformed cumulative density function (log CDF) for a positive Normal distribution.
165
171
 
166
172
  # Parameters
167
- - `x`: Value(s) at which to evaluate the log uPDF.
168
- - `mu`: The mean/location parameter(s).
169
- - `sigma`: The non-negative standard deviation parameter(s).
173
+ - `x`: Where to evaluate the log CDF.
174
+ - `mu`: The mean.
175
+ - `sigma`: The standard deviation.
170
176
 
171
177
  # Returns
172
178
  The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
@@ -175,84 +181,80 @@ def logcdf(
175
181
  Not numerically stable for small `x`.
176
182
  """
177
183
  # Cast to Array
178
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
184
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
179
185
 
180
186
  # Construct boolean mask for non-negative elements
181
- non_negative: Array = jnp.array(0.0) <= x
187
+ non_negative: Array = jnp.asarray(0.0) <= x
182
188
 
183
189
  A: Array = normal.logcdf(x, mu, sigma)
184
- B: Array = normal.logcdf(- mu/sigma, 0.0, 1.0)
185
- C: Array = normal.logcdf(mu/sigma, 0.0, 1.0)
190
+ B: Array = normal.logcdf(-mu / sigma, 0.0, 1.0)
191
+ C: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
186
192
 
187
193
  # Evaluate log CDF
188
- evals = jnp.where(
189
- non_negative,
190
- A + jnp.log1p(-jnp.exp(B - A)) - C,
191
- -jnp.inf)
194
+ evals = jnp.where(non_negative, A + jnp.log1p(-jnp.exp(B - A)) - C, -jnp.inf)
192
195
 
193
196
  return evals
194
197
 
198
+
195
199
  def ccdf(
196
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
200
+ x: Float[ArrayLike, "..."],
201
+ mu: Float[ArrayLike, "..."],
202
+ sigma: Float[ArrayLike, "..."],
197
203
  ) -> Float[Array, "..."]:
198
204
  """
199
205
  The complementary cumulative density function (cCDF) for a positive Normal distribution.
200
206
 
201
207
  # Parameters
202
- - `x`: Value(s) at which to evaluate the log uPDF.
203
- - `mu`: The mean/location parameter(s).
204
- - `sigma`: The non-negative standard deviation parameter(s).
208
+ - `x`: Where to evaluate the cCDF.
209
+ - `mu`: The mean.
210
+ - `sigma`: The standard deviation.
205
211
 
206
212
  # Returns
207
213
  The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
208
-
209
- # Notes
210
- Not numerically stable for small `x`.
211
214
  """
212
215
  # Cast to arrays
213
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
216
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
214
217
 
215
218
  # Construct boolean mask for non-negative elements
216
219
  non_negative: Array = 0.0 <= x
217
220
 
218
221
  # Compute intermediates
219
222
  A: Array = normal.cdf(-x, -mu, sigma)
220
- B: Array = normal.cdf(mu/sigma, 0.0, 1.0)
223
+ B: Array = normal.cdf(mu / sigma, 0.0, 1.0)
221
224
 
222
225
  # Evaluate cCDF
223
- evals = jnp.where(non_negative, A / B, jnp.array(1.0))
226
+ evals = jnp.where(non_negative, A / B, jnp.asarray(1.0))
224
227
 
225
228
  return evals
226
229
 
227
230
 
228
231
  def logccdf(
229
- x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], sigma: Float[ArrayLike, "..."]
232
+ x: Float[ArrayLike, "..."],
233
+ mu: Float[ArrayLike, "..."],
234
+ sigma: Float[ArrayLike, "..."],
230
235
  ) -> Float[Array, "..."]:
231
236
  """
232
237
  The log-transformed complementary cumulative density function (log cCDF) for a positive Normal distribution.
233
238
 
234
239
  # Parameters
235
- - `x`: Value(s) at which to evaluate the log uPDF.
236
- - `mu`: The mean/location parameter(s).
237
- - `sigma`: The non-negative standard deviation parameter(s).
240
+ - `x`: Where to evaluate the log cCDF.
241
+ - `mu`: The mean.
242
+ - `sigma`: The standard deviation.
238
243
 
239
244
  # Returns
240
245
  The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
241
-
242
- # Notes
243
- Not numerically stable for small `x`.
244
246
  """
245
247
  # Cast to arrays
246
- x, mu, sigma = jnp.array(x), jnp.array(mu), jnp.array(sigma)
248
+ x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
247
249
 
248
250
  # Construct boolean mask for non-negative elements
249
251
  non_negative: Array = 0.0 <= x
250
252
 
251
253
  # Compute intermediates
252
254
  A: Array = normal.logcdf(-x, -mu, sigma)
253
- B: Array = normal.logcdf(mu/sigma, 0.0, 1.0)
255
+ B: Array = normal.logcdf(mu / sigma, 0.0, 1.0)
254
256
 
255
257
  # Evaluate log cCDF
256
- evals = jnp.where(non_negative, A - B, jnp.array(0.0))
258
+ evals = jnp.where(non_negative, A - B, jnp.asarray(0.0))
257
259
 
258
260
  return evals
bayinx/mhx/vi/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
- from bayinx.mhx.vi.meanfield import MeanField as MeanField
2
- from bayinx.mhx.vi.normalizing_flow import NormalizingFlow as NormalizingFlow
3
- from bayinx.mhx.vi.standard import Standard as Standard
1
+ from bayinx.mhx.vi.meanfield import MeanField
2
+ from bayinx.mhx.vi.normalizing_flow import NormalizingFlow
3
+ from bayinx.mhx.vi.standard import Standard
4
+
5
+ __all__ = ['MeanField', 'NormalizingFlow', 'Standard']
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Self
1
+ from typing import Any, Dict, Generic, Self, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.numpy as jnp
@@ -10,8 +10,8 @@ from jaxtyping import Array, Float, Key, Scalar
10
10
  from bayinx.core import Model, Variational
11
11
  from bayinx.dists import normal
12
12
 
13
-
14
- class MeanField(Variational):
13
+ M = TypeVar('M', bound=Model)
14
+ class MeanField(Variational, Generic[M]):
15
15
  """
16
16
  A fully factorized Gaussian approximation to a posterior distribution.
17
17
 
@@ -19,9 +19,9 @@ class MeanField(Variational):
19
19
  - `var_params`: The variational parameters for the approximation.
20
20
  """
21
21
 
22
- var_params: Dict[str, Float[Array, "..."]]
22
+ var_params: Dict[str, Float[Array, "..."]] #todo: just expand to attributes
23
23
 
24
- def __init__(self, model: Model):
24
+ def __init__(self, model: M):
25
25
  """
26
26
  Constructs an unoptimized meanfield posterior approximation.
27
27
 
@@ -55,7 +55,6 @@ class MeanField(Variational):
55
55
 
56
56
  return filter_spec
57
57
 
58
-
59
58
  @eqx.filter_jit
60
59
  def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
61
60
  # Sample variational draws
@@ -1,4 +1,4 @@
1
- from typing import Any, Self, Tuple
1
+ from typing import Any, Generic, Self, Tuple, TypeVar
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.flatten_util as jfu
@@ -9,8 +9,8 @@ from jaxtyping import Array, Key, Scalar
9
9
 
10
10
  from bayinx.core import Flow, Model, Variational
11
11
 
12
-
13
- class NormalizingFlow(Variational):
12
+ M = TypeVar('M', bound=Model)
13
+ class NormalizingFlow(Variational, Generic[M]):
14
14
  """
15
15
  An ordered collection of diffeomorphisms that map a base distribution to a
16
16
  normalized approximation of a posterior distribution.
@@ -23,7 +23,7 @@ class NormalizingFlow(Variational):
23
23
  flows: list[Flow]
24
24
  base: Variational
25
25
 
26
- def __init__(self, base: Variational, flows: list[Flow], model: Model):
26
+ def __init__(self, base: Variational, flows: list[Flow], model: M):
27
27
  """
28
28
  Constructs an unoptimized normalizing flow posterior approximation.
29
29
 
bayinx/mhx/vi/standard.py CHANGED
@@ -1,29 +1,25 @@
1
- from typing import Callable
2
1
 
3
2
  import equinox as eqx
4
3
  import jax.numpy as jnp
5
4
  import jax.random as jr
6
5
  import jax.tree_util as jtu
7
6
  from jax.flatten_util import ravel_pytree
8
- from jaxtyping import Array, Float, Key
7
+ from jaxtyping import Array, Key
9
8
 
10
- from bayinx.core import Model, Variational
9
+ from bayinx.core._variational import M, Variational
11
10
  from bayinx.dists import normal
12
11
 
13
12
 
14
- class Standard(Variational):
13
+ class Standard(Variational[M]):
15
14
  """
16
15
  A standard normal approximation to a posterior distribution.
17
16
 
18
17
  # Attributes
19
18
  - `dim`: Dimension of the parameter space.
20
19
  """
21
-
22
20
  dim: int
23
- _unflatten: Callable[[Float[Array, "..."]], Model]
24
- _constraints: Model
25
21
 
26
- def __init__(self, model: Model):
22
+ def __init__(self, model: M):
27
23
  """
28
24
  Constructs a standard normal approximation to a posterior distribution.
29
25
 
@@ -1,7 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.5
3
+ Version: 0.3.7
4
4
  Summary: Bayesian Inference with JAX
5
+ License-File: LICENSE
5
6
  Requires-Python: >=3.12
6
7
  Requires-Dist: equinox>=0.11.12
7
8
  Requires-Dist: jax>=0.4.38
@@ -0,0 +1,35 @@
1
+ bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
2
+ bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
4
+ bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
5
+ bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
6
+ bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
7
+ bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
+ bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
9
+ bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
10
+ bayinx/core/_variational.py,sha256=b7xlUcw8JDDBfXgDLMcjsOMHpFZ2Tg3sEt965eWmctI,5431
11
+ bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
12
+ bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
+ bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
14
+ bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
15
+ bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
16
+ bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
17
+ bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
18
+ bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
19
+ bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
20
+ bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
21
+ bayinx/dists/censored/posnormal/r.py,sha256=hyuNR3HZY-Tgtso-WwjcZT6Ejxfyax_VKwIvVix44Jc,2362
22
+ bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
23
+ bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
24
+ bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
25
+ bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
26
+ bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
27
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
28
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
29
+ bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
30
+ bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
31
+ bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
32
+ bayinx-0.3.7.dist-info/METADATA,sha256=bQGouAjty73m1UeFCOWgRMw7Is0ffja7xDXAyS-EzDM,3079
33
+ bayinx-0.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
+ bayinx-0.3.7.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
35
+ bayinx-0.3.7.dist-info/RECORD,,
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Todd McCready
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -1,34 +0,0 @@
1
- bayinx/__init__.py,sha256=5fb_tGeEVnrNt6IQqu7gZaJskBJHqjcg08JRPrY2ANo,139
2
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
4
- bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
5
- bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
6
- bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
7
- bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
- bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
9
- bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
10
- bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
11
- bayinx/dists/__init__.py,sha256=qPQrl5vkS9K56GzIaHZXkSUP07YAu4lVB8K2yQ1m3SY,78
12
- bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
- bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
14
- bayinx/dists/normal.py,sha256=BLlp7hGMAxUbroROvzA5ChH5YLXgadeK4VOuBtjjdjs,3978
15
- bayinx/dists/posnormal.py,sha256=NNr5OHv1fWCxYvc6hwUMIGXX31UAg0sEnc4tsxHLjUg,7726
16
- bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
17
- bayinx/dists/censored/__init__.py,sha256=p8T03TenD-_8YNiOgB_RKksq8hFNFejA5bnoK4JJ8Ms,67
18
- bayinx/dists/censored/gamma2/__init__.py,sha256=qqm0n2hfid617PvyFRHAOMAp3AvpOlt5v3ns8HgTD7E,33
19
- bayinx/dists/censored/gamma2/r.py,sha256=dE0MNTAl0E6npQhFONv341U7XbomBB-fNzQhgRjxYpk,2436
20
- bayinx/dists/censored/posnormal/__init__.py,sha256=qqm0n2hfid617PvyFRHAOMAp3AvpOlt5v3ns8HgTD7E,33
21
- bayinx/dists/censored/posnormal/r.py,sha256=4MfFkQ2klzOZJNjxS9g4zz1bdoJ6ehBxZQi6QkmPGgE,2232
22
- bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
23
- bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
24
- bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
25
- bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
26
- bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
27
- bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
28
- bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
29
- bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
30
- bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
31
- bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
32
- bayinx-0.3.5.dist-info/METADATA,sha256=Hj8GWJef3kfJ6umsHGIFWovYXXtPegAlcsopunoHFFs,3057
33
- bayinx-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
- bayinx-0.3.5.dist-info/RECORD,,
File without changes
File without changes