bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
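Note: the headline change is a restructure. The `bayinx.mhx.vi` namespace is promoted to `bayinx.vi`, normalizing flows move to a top-level `bayinx.flows` package, and the private core modules (`_constraint.py`, `_flow.py`, `_model.py`, `_parameter.py`, `_variational.py`) are replaced by public counterparts. A rough migration sketch, based only on the renamed paths above (the actual re-exports in the 0.5.3 `__init__.py` files are not visible in this diff):

# 0.3.10 layout (removed):
# from bayinx.mhx.vi import meanfield, normalizing_flow, standard

# 0.5.3 layout (module paths from the file list; hypothetical usage):
from bayinx.vi import meanfield, normalizing_flow, standard
from bayinx.flows import diagaffine, planar, sylvester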
bayinx/core/_parameter.py DELETED
@@ -1,44 +0,0 @@
- from typing import Generic, Self, TypeVar
-
- import equinox as eqx
- import jax.tree as jt
- from jaxtyping import PyTree
-
- T = TypeVar("T", bound=PyTree)
- class Parameter(eqx.Module, Generic[T]):
-     """
-     A container for a parameter of a `Model`.
-
-     Subclasses can be constructed for custom filter specifications(`filter_spec`).
-
-     # Attributes
-     - `vals`: The parameter's value(s).
-     """
-
-     vals: T
-
-     def __init__(self, values: T):
-         # Insert parameter values
-         self.vals = values
-
-     def __call__(self) -> T:
-         return self.vals
-
-     # Default filter specification
-     @property
-     @eqx.filter_jit
-     def filter_spec(self) -> Self:
-         """
-         Generates a filter specification to filter out static parameters.
-         """
-         # Generate empty specification
-         filter_spec = jt.map(lambda _: False, self)
-
-         # Specify Array leaves
-         filter_spec = eqx.tree_at(
-             lambda params: params.vals,
-             filter_spec,
-             replace=jt.map(eqx.is_array_like, self.vals),
-         )
-
-         return filter_spec
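For orientation, a minimal sketch of how this removed container was used, assuming `Parameter` was re-exported from `bayinx.core` in 0.3.10 (the export path is an assumption, not shown in this diff):

import equinox as eqx
import jax.numpy as jnp

from bayinx.core import Parameter  # assumed 0.3.10 export path

# Wrap an array; filter_spec marks array-like leaves True, all else False
p = Parameter(jnp.zeros(3))

# Split trainable values from static structure for optimization
dyn, static = eqx.partition(p, p.filter_spec)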
bayinx/core/_variational.py DELETED
@@ -1,181 +0,0 @@
- from abc import abstractmethod
- from functools import partial
- from typing import Any, Callable, Generic, Self, Tuple, TypeVar
-
- import equinox as eqx
- import jax
- import jax.lax as lax
- import jax.numpy as jnp
- import jax.random as jr
- import optax as opx
- from jaxtyping import Array, Key, PyTree, Scalar
- from optax import GradientTransformation, OptState, Schedule
-
- from ._model import Model
-
- M = TypeVar('M', bound=Model)
- class Variational(eqx.Module, Generic[M]):
-     """
-     An abstract base class used to define variational methods.
-
-     # Attributes
-     - `_unflatten`: A function to transform draws from the variational distribution back to a `Model`.
-     - `_constraints`: The static component of a partitioned `Model` used to initialize the `Variational` object.
-     """
-
-     _unflatten: Callable[[Array], M]
-     _constraints: M
-
-     @abstractmethod
-     def filter_spec(self):
-         """
-         Filter specification for dynamic and static components of the `Variational`.
-         """
-         pass
-
-     @abstractmethod
-     def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
-         """
-         Sample from the variational distribution.
-         """
-         pass
-
-     @abstractmethod
-     def eval(self, draws: Array) -> Array:
-         """
-         Evaluate the variational distribution at `draws`.
-         """
-         pass
-
-     @abstractmethod
-     def elbo(self, n: int, key: Key, data: Any = None) -> Array:
-         """
-         Evaluate the ELBO.
-         """
-         pass
-
-     @abstractmethod
-     def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
-         """
-         Evaluate the gradient of the ELBO.
-         """
-         pass
-
-     @eqx.filter_jit
-     @partial(jax.vmap, in_axes=(None, 0, None))
-     def eval_model(self, draws: Array, data: Any = None) -> Array:
-         """
-         Reconstruct models from variational draws and evaluate their posterior density.
-
-         # Parameters
-         - `draws`: A set of variational draws.
-         - `data`: Data used to evaluate the posterior(if needed).
-         """
-         # Unflatten variational draw
-         model: M = self._unflatten(draws)
-
-         # Combine with constraints
-         model: M = eqx.combine(model, self._constraints)
-
-         # Evaluate posterior density
-         return model.eval(data)
-
-     @eqx.filter_jit
-     def fit(
-         self,
-         max_iters: int,
-         data: Any = None,
-         learning_rate: float = 1,
-         weight_decay: float = 1e-4,
-         tolerance: float = 1e-4,
-         var_draws: int = 1,
-         key: Key = jr.PRNGKey(0),
-     ) -> Self:
-         """
-         Optimize the variational distribution.
-
-         # Parameters
-         - `max_iters`: Maximum number of iterations for the optimization loop.
-         - `data`: Data to evaluate the posterior density with(if available).
-         - `learning_rate`: Initial learning rate for optimization.
-         - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
-         - `var_draws`: Number of variational draws to draw each iteration.
-         - `key`: A PRNG key.
-         """
-         # Partition variational
-         dyn, static = eqx.partition(self, self.filter_spec)
-
-         # Construct scheduler
-         schedule: Schedule = opx.warmup_cosine_decay_schedule(
-             init_value=1e-16, peak_value=learning_rate, warmup_steps=int(max_iters/10), decay_steps=max_iters-int(max_iters/10)
-         )
-
-         # Initialize optimizer
-         optim: GradientTransformation = opx.chain(
-             opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
-         )
-         opt_state: OptState = optim.init(dyn)
-
-         # Optimization loop helper functions
-         @eqx.filter_jit
-         def condition(state: Tuple[Self, OptState, Scalar, Key]):
-             # Unpack iteration state
-             dyn, opt_state, i, key = state
-
-             return i < max_iters
-
-         @eqx.filter_jit
-         def body(state: Tuple[Self, OptState, Scalar, Key]):
-             # Unpack iteration state
-             dyn, opt_state, i, key = state
-
-             # Update iteration
-             i = i + 1
-
-             # Update PRNG key
-             key, _ = jr.split(key)
-
-             # Reconstruct variational
-             vari = eqx.combine(dyn, static)
-
-             # Compute gradient of the ELBO
-             updates: PyTree = vari.elbo_grad(var_draws, key, data)
-
-             # Compute updates
-             updates, opt_state = optim.update(
-                 updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
-             )
-
-             # Update variational distribution
-             dyn = eqx.apply_updates(dyn, updates)
-
-             return dyn, opt_state, i, key
-
-         # Run optimization loop
-         dyn = lax.while_loop(
-             cond_fun=condition,
-             body_fun=body,
-             init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
-         )[0]
-
-         # Return optimized variational
-         return eqx.combine(dyn, static)
-
-     @eqx.filter_jit
-     def posterior_predictive(
-         self, func: Callable[[M, Any], Array], n: int, data: Any = None, key: Key = jr.PRNGKey(0)
-     ) -> Array:
-         # Sample draws from the variational approximation
-         draws: Array = self.sample(n, key)
-
-         # Evaluate posterior predictive
-         @jax.jit
-         @jax.vmap
-         def evaluate(draw: Array, data: Any = None):
-             # Reconstruct model
-             model: M = self._unflatten(draw)
-
-             # Evaluate
-             return func(model, data)
-
-         return evaluate(draws, data)
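Two things worth noting in `fit` above. First, optax optimizers minimize, so the chain opens with `scale(-1.0)` to flip the ELBO gradient's sign, i.e. gradient ascent on the ELBO. Second, `tolerance` is documented but unused: the `while_loop` condition only checks `i < max_iters`, so no early stopping actually occurs. The optimizer construction in isolation (values mirror the defaults above):

import optax as opx

max_iters = 1_000
schedule = opx.warmup_cosine_decay_schedule(
    init_value=1e-16,                         # start near zero
    peak_value=1.0,                           # default learning_rate
    warmup_steps=max_iters // 10,             # 10% warmup
    decay_steps=max_iters - max_iters // 10,  # cosine decay over the rest
)
# scale(-1.0) turns minimization into ascent on the ELBO
optim = opx.chain(opx.scale(-1.0), opx.nadamw(schedule, weight_decay=1e-4))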
bayinx/dists/censored/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import posnormal
-
- __all__ = ["posnormal"]
bayinx/dists/censored/gamma2/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import r
-
- __all__ = ["r"]
bayinx/dists/censored/gamma2/r.py DELETED
@@ -1,68 +0,0 @@
- import jax.lax as lax
- import jax.numpy as jnp
- from jax.scipy.special import gammaincc
- from jaxtyping import Array, ArrayLike, Float
-
- from bayinx.dists import gamma2
-
-
- def prob(
-     x: Float[ArrayLike, "..."],
-     mu: Float[ArrayLike, "..."],
-     nu: Float[ArrayLike, "..."],
-     censor: Float[ArrayLike, "..."],
- ) -> Float[Array, "..."]:
-     """
-     The mixed probability mass/density function (PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
-
-     # Parameters
-     - `x`: Value(s) at which to evaluate the PMF/PDF.
-     - `mu`: The positive mean.
-     - `nu`: The positive inverse dispersion.
-     - `censor`: The positive censor value.
-
-     # Returns
-     The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
-     """
-     evals: Array = jnp.zeros_like(x * 1.0) # ensure float dtype
-
-     # Construct boolean masks
-     uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
-     censored: Array = jnp.array(x == censor) # pyright: ignore
-
-     # Evaluate probability mass/density function
-     evals = jnp.where(uncensored, gamma2.prob(x, mu, nu), evals)
-     evals = jnp.where(censored, gammaincc(nu, x * nu / mu), evals) # pyright: ignore
-
-     return evals
-
-
- def logprob(
-     x: Float[ArrayLike, "..."],
-     mu: Float[ArrayLike, "..."],
-     nu: Float[ArrayLike, "..."],
-     censor: Float[ArrayLike, "..."],
- ) -> Float[Array, "..."]:
-     """
-     The log-transformed mixed probability mass/density function (log PMF/PDF) for a (mean-inverse dispersion parameterized) Gamma distribution.
-
-     # Parameters
-     - `x`: Value(s) at which to evaluate the log PMF/PDF.
-     - `mu`: The positive mean/location.
-     - `nu`: The positive inverse dispersion.
-     - `censor`: The positive censor value.
-
-     # Returns
-     The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `nu`, and `censor`.
-     """
-     evals: Array = jnp.full_like(x * 1.0, -jnp.inf) # ensure float dtype
-
-     # Construct boolean masks
-     uncensored: Array = jnp.array(jnp.logical_and(0.0 < x, x < censor)) # pyright: ignore
-     censored: Array = jnp.array(x == censor) # pyright: ignore
-
-     # Evaluate log probability mass/density function
-     evals = jnp.where(uncensored, gamma2.logprob(x, mu, nu), evals)
-     evals = jnp.where(censored, lax.log(gammaincc(nu, x * nu / mu)), evals) # pyright: ignore
-
-     return evals
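Reading off the code above, the removed module evaluated the mixed mass/density (notation assumed here: shape \nu and rate \nu/\mu, so the mean is \mu):

f(x \mid \mu, \nu, c) =
\begin{cases}
  \dfrac{(\nu/\mu)^{\nu}}{\Gamma(\nu)}\, x^{\nu-1} e^{-x\nu/\mu}, & 0 < x < c \\
  Q\!\left(\nu, \dfrac{c\nu}{\mu}\right), & x = c
\end{cases}

where Q is the regularized upper incomplete gamma function (`gammaincc`): the survival probability P(X \ge c) is lumped as a point mass on the censor value.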
bayinx/dists/censored/posnormal/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import r
-
- __all__ = ["r"]
bayinx/dists/censored/posnormal/r.py DELETED
@@ -1,116 +0,0 @@
- import jax.numpy as jnp
- import jax.random as jr
- from jaxtyping import Array, ArrayLike, Float, Key
-
- from bayinx.dists import posnormal
-
-
- def prob(
-     x: Float[ArrayLike, "..."],
-     mu: Float[ArrayLike, "..."],
-     sigma: Float[ArrayLike, "..."],
-     censor: Float[ArrayLike, "..."],
- ) -> Float[Array, "..."]:
-     """
-     The mixed probability mass/density function (PMF/PDF) for a right-censored positive Normal distribution.
-
-     # Parameters
-     - `x`: Value(s) at which to evaluate the PMF/PDF.
-     - `mu`: The mean.
-     - `sigma`: The positive standard deviation.
-     - `censor`: The positive censor value.
-
-     # Returns
-     The PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
-     """
-     # Cast to Array
-     x, mu, sigma, censor = (
-         jnp.asarray(x),
-         jnp.asarray(mu),
-         jnp.asarray(sigma),
-         jnp.asarray(censor),
-     )
-
-     # Construct boolean masks
-     uncensored: Array = jnp.logical_and(0.0 < x, x < censor)
-     censored: Array = x == censor
-
-     # Evaluate probability mass/density function
-     evals = jnp.where(uncensored, posnormal.prob(x, mu, sigma), 0.0)
-     evals = jnp.where(censored, posnormal.ccdf(x, mu, sigma), evals)
-
-     return evals
-
-
- def logprob(
-     x: Float[ArrayLike, "..."],
-     mu: Float[ArrayLike, "..."],
-     sigma: Float[ArrayLike, "..."],
-     censor: Float[ArrayLike, "..."],
- ) -> Float[Array, "..."]:
-     """
-     The log-transformed mixed probability mass/density function (log PMF/PDF) for a right-censored positive Normal distribution.
-
-     # Parameters
-     - `x`: Where to evaluate the log PMF/PDF.
-     - `mu`: The mean.
-     - `sigma`: The standard deviation.
-     - `censor`: The censor.
-
-     # Returns
-     The log PMF/PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, `sigma`, and `censor`.
-     """
-     # Cast to Array
-     x, mu, sigma, censor = (
-         jnp.asarray(x),
-         jnp.asarray(mu),
-         jnp.asarray(sigma),
-         jnp.asarray(censor),
-     )
-
-     # Construct boolean masks for censoring
-     uncensored: Array = jnp.logical_and(jnp.asarray(0.0) < x, x < censor)
-     censored: Array = x == censor
-
-     # Evaluate log probability mass/density function
-     evals = jnp.where(uncensored, posnormal.logprob(x, mu, sigma), -jnp.inf)
-     evals = jnp.where(censored, posnormal.logccdf(x, mu, sigma), evals)
-
-     return evals
-
- def sample(
-     n: int,
-     mu: Float[ArrayLike, "..."],
-     sigma: Float[ArrayLike, "..."],
-     censor: Float[ArrayLike, "..."] = jnp.inf,
-     key: Key = jr.PRNGKey(0)
- ) -> Float[Array, "..."]:
-     """
-     Sample from a right-censored positive Normal distribution.
-
-     # Parameters
-     - `n`: Number of draws to sample per-parameter.
-     - `mu`: The mean.
-     - `sigma`: The standard deviation.
-     - `censor`: The censor.
-
-     # Returns
-     Draws from a right-censored positive Normal distribution. The output will have the shape of (n,) + the broadcasted shapes of `mu`, `sigma`, and `censor`.
-     """
-     # Cast to Array
-     mu, sigma, censor = (
-         jnp.asarray(mu),
-         jnp.asarray(sigma),
-         jnp.asarray(censor),
-     )
-
-     # Derive shape
-     shape = (n,) + jnp.broadcast_shapes(mu.shape, sigma.shape, censor.shape)
-
-     # Draw from positive normal
-     draws = jr.truncated_normal(key, 0.0, jnp.inf, shape) * sigma + mu
-
-     # Censor values
-     draws = jnp.where(censor <= draws, censor, draws)
-
-     return draws
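A minimal usage sketch against the removed 0.3.10 module above (this import path no longer exists in 0.5.3):

import jax.numpy as jnp
import jax.random as jr

from bayinx.dists.censored.posnormal import r  # 0.3.10 path, removed in 0.5.3

# 1000 draws per parameter from a positive Normal right-censored at 2.0;
# draws landing at or above the censor are recorded as exactly 2.0
draws = r.sample(1000, mu=jnp.array([0.5, 1.0]), sigma=0.3, censor=2.0,
                 key=jr.PRNGKey(1))
print(draws.shape)  # (1000, 2)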
bayinx/dists/gamma2.py DELETED
@@ -1,49 +0,0 @@
- import jax.lax as lax
- import jax.numpy as jnp
- from jax.scipy.special import gammaln
- from jaxtyping import Array, ArrayLike, Float
-
-
- def prob(
-     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
- ) -> Float[Array, "..."]:
-     """
-     The probability density function (PDF) for a (mean-precision parameterized) Gamma distribution.
-
-     # Parameters
-     - `x`: Value(s) at which to evaluate the PDF.
-     - `mu`: The positive mean.
-     - `nu`: The positive inverse dispersion.
-
-     # Returns
-     The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
-     """
-     # Cast to Array
-     x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
-
-     return lax.exp(logprob(x, mu, nu))
-
-
- def logprob(
-     x: Float[ArrayLike, "..."], mu: Float[ArrayLike, "..."], nu: Float[ArrayLike, "..."]
- ) -> Float[Array, "..."]:
-     """
-     The log-transformed probability density function (log PDF) for a (mean-precision parameterized) Gamma distribution.
-
-     # Parameters
-     - `x`: Value(s) at which to evaluate the log PDF.
-     - `mu`: The positive mean/location.
-     - `nu`: The positive inverse dispersion.
-
-     # Returns
-     The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `nu`.
-     """
-     # Cast to Array
-     x, mu, nu = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(nu)
-
-     return (
-         -gammaln(nu)
-         + nu * (lax.log(nu) - lax.log(mu))
-         + (nu - 1.0) * lax.log(x)
-         - (x * nu / mu)
-     ) # pyright: ignore