bayinx 0.1.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/__init__.py CHANGED
@@ -0,0 +1 @@
+ from bayinx.core.model import Model as Model
bayinx/core/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from bayinx.core.flow import Flow as Flow
+ from bayinx.core.model import Model as Model
+ from bayinx.core.variational import Variational as Variational
bayinx/core/flow.py ADDED
@@ -0,0 +1,68 @@
+ from abc import abstractmethod
+ from typing import Callable, Dict, Self, Tuple
+
+ import equinox as eqx
+ from jaxtyping import Array, Float
+
+ from bayinx.core.utils import __MyMeta
+
+
+ class Flow(eqx.Module, metaclass=__MyMeta):
+     """
+     A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
+
+     # Attributes
+     - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
+     - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
+
+     @abstractmethod
+     def forward(self, draws: Array) -> Array:
+         """
+         Computes the forward transformation of `draws`.
+         """
+         pass
+
+     @abstractmethod
+     def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
+         """
+         Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
+
+         # Returns
+         A tuple of JAX Arrays containing the log-absolute-determinants of the Jacobians and the transformed draws.
+         """
+         pass
+
+     def __init_subclass__(cls):
+         # Add constrain_pars method
+         def constrain_pars(self: Self):
+             """
+             Constrain `params` to the appropriate domain.
+
+             # Returns
+             A dictionary of transformed JAX Arrays representing the constrained parameters.
+             """
+             t_params = self.params.copy()  # shallow copy so the module's own params are not mutated
+
+             for par, map in self.constraints.items():
+                 t_params[par] = map(t_params[par])
+
+             return t_params
+
+         cls.constrain_pars = eqx.filter_jit(constrain_pars)
+
+         # Add transform_pars method if not present
+         if not callable(getattr(cls, "transform_pars", None)):
+             def transform_pars(self: Self) -> Dict[str, Array]:
+                 """
+                 Apply a custom transformation to `params` if needed.
+
+                 # Returns
+                 A dictionary of transformed JAX Arrays representing the transformed parameters.
+                 """
+                 return self.constrain_pars()
+
+             cls.transform_pars = eqx.filter_jit(transform_pars)
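For orientation, here is a minimal sketch of what a `Flow` subclass looks like under this contract; the `Shift` class and its zero log-determinant are hypothetical illustrations, not part of the package. The `constrain_pars` it calls is the method injected by `Flow.__init_subclass__` above.

```python
from typing import Callable, Dict, Tuple

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float

from bayinx.core import Flow


class Shift(Flow):
    """A pure translation: x -> x + shift."""

    params: Dict[str, Float[Array, "..."]]
    constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
        eqx.field(static=True)
    )

    def __init__(self, dim: int):
        self.params = {"shift": jnp.zeros(dim)}
        self.constraints = {}

    def forward(self, draws: Array) -> Array:
        # `constrain_pars` was injected by Flow.__init_subclass__
        return draws + self.constrain_pars()["shift"]

    def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
        # A translation has |det J| = 1, so the log-adjustment is zero per draw
        ladj = jnp.zeros(draws.shape[0])
        return ladj, self.forward(draws)
```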
bayinx/core/model.py ADDED
@@ -0,0 +1,55 @@
+ from abc import abstractmethod
+ from typing import Any, Callable, Dict
+
+ import equinox as eqx
+ from jaxtyping import Array, Scalar
+
+ from bayinx.core.utils import __MyMeta
+
+
+ class Model(eqx.Module, metaclass=__MyMeta):
+     """
+     A superclass used to define probabilistic models.
+
+     # Attributes
+     - `params`: A dictionary of JAX Arrays representing parameters of the model.
+     - `constraints`: A dictionary of functions that constrain their corresponding parameter.
+     """
+
+     params: Dict[str, Array]
+     constraints: Dict[str, Callable[[Array], Array]]
+
+     @abstractmethod
+     def eval(self, data: Any) -> Scalar:
+         pass
+
+     def __init_subclass__(cls):
+         # Add constrain_pars method
+         def constrain_pars(self: Model) -> Dict[str, Array]:
+             """
+             Constrain `params` to the appropriate domain.
+
+             # Returns
+             A dictionary of transformed JAX Arrays representing the constrained parameters.
+             """
+             t_params = self.params.copy()  # shallow copy so the module's own params are not mutated
+
+             for par, map in self.constraints.items():
+                 t_params[par] = map(t_params[par])
+
+             return t_params
+
+         cls.constrain_pars = eqx.filter_jit(constrain_pars)
+
+         # Add transform_pars method if not present
+         if not callable(getattr(cls, "transform_pars", None)):
+             def transform_pars(self: Model) -> Dict[str, Array]:
+                 """
+                 Apply a custom transformation to `params` if needed.
+
+                 # Returns
+                 A dictionary of transformed JAX Arrays representing the transformed parameters.
+                 """
+                 return self.constrain_pars()
+
+             cls.transform_pars = eqx.filter_jit(transform_pars)
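A minimal sketch of a concrete `Model` under this contract (the `NormalMean` class is a hypothetical illustration, not part of the package; it uses `bayinx.dists.normal` from later in this diff):

```python
import jax.numpy as jnp
from jaxtyping import Array, Scalar

from bayinx.core import Model
from bayinx.dists import normal


class NormalMean(Model):
    """Unknown mean, known unit scale, standard normal prior on `mu`."""

    def __init__(self):
        self.params = {"mu": jnp.array(0.0)}
        self.constraints = {}

    def eval(self, data: Array) -> Scalar:
        # `transform_pars` was injected by Model.__init_subclass__
        mu = self.transform_pars()["mu"]

        # Log-likelihood plus log-prior, i.e. the log joint density
        target = normal.logprob(x=data, mu=mu, sigma=jnp.array(1.0)).sum()
        target += normal.logprob(x=mu, mu=jnp.array(0.0), sigma=jnp.array(1.0))
        return target
```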
bayinx/core/utils.py ADDED
@@ -0,0 +1,54 @@
+ from typing import Callable, Dict
+
+ import equinox as eqx
+ from jaxtyping import Array
+
+
+ class __MyMeta(type(eqx.Module)):
+     """
+     Metaclass to ensure attribute types are respected.
+     """
+
+     def __call__(cls, *args, **kwargs):
+         obj = super().__call__(*args, **kwargs)
+
+         # Check parameters are a Dict of JAX Arrays
+         if not isinstance(obj.params, Dict):
+             raise ValueError(
+                 f"Model {cls.__name__} must initialize 'params' as a dictionary."
+             )
+
+         for key, value in obj.params.items():
+             if not isinstance(value, Array):
+                 raise TypeError(f"Parameter '{key}' must be a JAX Array.")
+
+         # Check constraints are a Dict of functions
+         if not isinstance(obj.constraints, Dict):
+             raise ValueError(
+                 f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
+             )
+
+         for key, value in obj.constraints.items():
+             if not isinstance(value, Callable):
+                 raise TypeError(f"Constraint for parameter '{key}' must be a function.")
+
+         # Check that the constrain_pars method returns a dict equivalent to `params`
+         t_params: Dict[str, Array] = obj.constrain_pars()
+
+         if not isinstance(t_params, Dict):
+             raise ValueError(
+                 f"The 'constrain_pars' method of {cls.__name__} must return a Dict of JAX Arrays."
+             )
+
+         for key, value in t_params.items():
+             if not isinstance(value, Array):
+                 raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
+
+             if not value.shape == obj.params[key].shape:
+                 raise ValueError(
+                     f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
+                 )
+
+         # Check transform_pars
+
+         return obj
bayinx/core/variational.py ADDED
@@ -0,0 +1,159 @@
+ from abc import abstractmethod
+ from typing import Any, Callable, Self, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.lax as lax
+ import jax.numpy as jnp
+ import jax.random as jr
+ import optax as opx
+ from jaxtyping import Array, Float, Key, PyTree, Scalar
+ from optax import GradientTransformation, OptState, Schedule
+
+ from bayinx.core import Model
+
+
+ class Variational(eqx.Module):
+     """
+     A superclass used to define variational methods.
+
+     # Attributes
+     - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
+     - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
+     """
+
+     _unflatten: Callable[[Float[Array, "..."]], Model]
+     _constraints: Model
+
+     @abstractmethod
+     def sample(self, n: int, key: Key) -> Array:
+         """
+         Sample from the variational distribution.
+         """
+         pass
+
+     @abstractmethod
+     def eval(self, draws: Array) -> Array:
+         """
+         Evaluate the variational distribution at `draws`.
+         """
+         pass
+
+     @abstractmethod
+     def elbo(self, n: int, key: Key, data: Any = None) -> Array:
+         """
+         Evaluate the ELBO.
+         """
+         pass
+
+     @abstractmethod
+     def elbo_grad(self, n: int, key: Key, data: Any = None) -> PyTree:
+         """
+         Evaluate the gradient of the ELBO.
+         """
+         pass
+
+     @abstractmethod
+     def filter_spec(self):
+         """
+         Filter specification for dynamic and static components of the `Variational`.
+         """
+         pass
+
+     def __init_subclass__(cls):
+         """
+         Construct methods that are shared across all VI methods.
+         """
+
+         def eval_model(self, draws: Array, data: Any = None) -> Array:
+             """
+             Reconstruct models from variational draws and evaluate their posterior density.
+
+             # Parameters
+             - `draws`: A set of variational draws.
+             - `data`: Data used to evaluate the posterior (if needed).
+             """
+             # Unflatten variational draw
+             model: Model = self._unflatten(draws)
+
+             # Combine with constraints
+             model: Model = eqx.combine(model, self._constraints)
+
+             # Evaluate posterior density
+             return model.eval(data)
+
+         cls.eval_model = jax.vmap(eqx.filter_jit(eval_model), (None, 0, None))
+
+         def fit(
+             self,
+             max_iters: int,
+             data: Any = None,
+             learning_rate: float = 1,
+             tolerance: float = 1e-4,
+             var_draws: int = 1,
+             key: Key = jr.PRNGKey(0),
+         ) -> Self:
+             """
+             Optimize the variational distribution.
+
+             # Parameters
+             - `max_iters`: Maximum number of iterations for the optimization loop.
+             - `data`: Data to evaluate the posterior density with (if available).
+             - `learning_rate`: Initial learning rate for optimization.
+             - `tolerance`: Relative tolerance of ELBO decrease for stopping early.
+             - `var_draws`: Number of variational draws to draw each iteration.
+             - `key`: A PRNG key.
+             """
+             # Construct scheduler
+             schedule: Schedule = opx.cosine_decay_schedule(
+                 init_value=learning_rate, decay_steps=max_iters
+             )
+
+             # Initialize optimizer
+             optim: GradientTransformation = opx.chain(
+                 opx.scale(-1.0), opx.nadam(schedule)
+             )
+             opt_state: OptState = optim.init(eqx.filter(self, self.filter_spec()))
+
+             # Optimization loop helper functions
+             @eqx.filter_jit
+             def condition(state: Tuple[Self, OptState, Scalar, Key]):
+                 # Unpack iteration state
+                 self, opt_state, i, key = state
+
+                 return i < max_iters
+
+             @eqx.filter_jit
+             def body(state: Tuple[Self, OptState, Scalar, Key]):
+                 # Unpack iteration state
+                 self, opt_state, i, key = state
+
+                 # Update iteration
+                 i = i + 1
+
+                 # Update PRNG key
+                 key, _ = jr.split(key)
+
+                 # Compute gradient of the ELBO
+                 updates: PyTree = self.elbo_grad(var_draws, key, data)
+
+                 # Compute updates
+                 updates, opt_state = optim.update(
+                     updates, opt_state, eqx.filter(self, self.filter_spec())
+                 )
+
+                 # Update variational distribution
+                 self: Self = eqx.apply_updates(self, updates)
+
+                 return self, opt_state, i, key
+
+             # Run optimization loop
+             self = lax.while_loop(
+                 cond_fun=condition,
+                 body_fun=body,
+                 init_val=(self, opt_state, jnp.array(0, jnp.uint32), key),
+             )[0]
+
+             return self
+
+         cls.fit = eqx.filter_jit(fit)
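A minimal end-to-end sketch of the intended workflow, assuming the hypothetical `NormalMean` model sketched after `model.py` above and the `MeanField` approximation that appears later in this diff:

```python
import jax.numpy as jnp
import jax.random as jr

from bayinx.machinery.variational import MeanField

# Simulated data with true mean 2.0
data = 2.0 + jr.normal(jr.PRNGKey(1), (100,))

# Construct and optimize the approximation (`fit` is injected by Variational.__init_subclass__)
vari = MeanField(NormalMean())
vari = vari.fit(max_iters=1000, data=data, learning_rate=0.1)

# The variational mean and scale should sit near the posterior mean and sd of mu
print(vari.var_params["mean"], jnp.exp(vari.var_params["log_std"]))
```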
bayinx/dists/__init__.py ADDED
File without changes
bayinx/dists/bernoulli.py ADDED
@@ -0,0 +1,33 @@
+ import jax.lax as lax
+ from jaxtyping import Array, ArrayLike, Real, UInt
+
+
+ # MARK: Functions ----
+ def prob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+     """
+     The probability mass function (PMF) for a Bernoulli distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the PMF.
+     - `p`: The probability parameter(s).
+
+     # Returns
+     The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+     """
+
+     return lax.pow(p, x) * lax.pow(1 - p, 1 - x)
+
+
+ def logprob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+     """
+     The log probability mass function (log PMF) for a Bernoulli distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the log PMF.
+     - `p`: The probability parameter(s).
+
+     # Returns
+     The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+     """
+
+     return lax.log(p) * x + (1 - x) * lax.log(1 - p)
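As a quick sanity check (a sketch, not part of the package), `logprob` should agree with the log of `prob` wherever the PMF is positive; the 0/1 outcomes are passed as floats so `lax.pow` stays in floating point:

```python
import jax.numpy as jnp

from bayinx.dists import bernoulli

x = jnp.array([0.0, 1.0, 1.0, 0.0])
p = 0.3

print(bernoulli.prob(x, p))              # [0.7 0.3 0.3 0.7]
print(jnp.exp(bernoulli.logprob(x, p)))  # should match the line above
```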
bayinx/dists/binomial.py ADDED
File without changes
bayinx/dists/gamma.py ADDED
File without changes
bayinx/dists/gamma2.py ADDED
File without changes
bayinx/dists/normal.py ADDED
@@ -0,0 +1,83 @@
+ # MARK: Imports ----
+ import jax.lax as _lax
+
+ ## Typing
+ from jaxtyping import Array, Real
+
+ # MARK: Constants
+ _PI = 3.141592653589793
+
+
+ # MARK: Functions ----
+ def prob(
+     x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The probability density function (PDF) for a Normal distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the PDF.
+     - `mu`: The mean/location parameter(s).
+     - `sigma`: The non-negative standard deviation parameter(s).
+
+     # Returns
+     The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+     """
+
+     return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / (
+         sigma * _lax.sqrt(2.0 * _PI)
+     )
+
+
+ def logprob(
+     x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The log of the probability density function (log PDF) for a Normal distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the log PDF.
+     - `mu`: The mean/location parameter(s).
+     - `sigma`: The non-negative standard deviation parameter(s).
+
+     # Returns
+     The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+     """
+
+     return -_lax.log(sigma * _lax.sqrt(2.0 * _PI)) - 0.5 * _lax.square((x - mu) / sigma)
+
+
+ def uprob(
+     x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The unnormalized probability density function (uPDF) for a Normal distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the uPDF.
+     - `mu`: The mean/location parameter(s).
+     - `sigma`: The non-negative standard deviation parameter(s).
+
+     # Returns
+     The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+     """
+
+     return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma
+
+
+ def ulogprob(
+     x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
+
+     # Parameters
+     - `x`: Value(s) at which to evaluate the log uPDF.
+     - `mu`: The mean/location parameter(s).
+     - `sigma`: The non-negative standard deviation parameter(s).
+
+     # Returns
+     The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
+     """
+
+     return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma)
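The unnormalized variants drop only the constant $\log\sqrt{2\pi}$ from the normalized ones; a quick sketch (not part of the package) verifying the relationship:

```python
import jax.numpy as jnp

from bayinx.dists import normal

x, mu, sigma = jnp.array(1.5), jnp.array(0.0), jnp.array(2.0)

# logprob and ulogprob differ by the constant log(sqrt(2*pi))
lhs = normal.logprob(x, mu, sigma)
rhs = normal.ulogprob(x, mu, sigma) - 0.5 * jnp.log(2.0 * jnp.pi)
print(lhs, rhs)  # equal up to floating-point error
```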
bayinx/machinery/__init__.py ADDED
File without changes
bayinx/machinery/variational/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from bayinx.machinery.variational.meanfield import MeanField as MeanField
+ from bayinx.machinery.variational.normalizing_flow import (
+     NormalizingFlow as NormalizingFlow,
+ )
+ from bayinx.machinery.variational.standard import Standard as Standard
bayinx/machinery/variational/flows/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from bayinx.machinery.variational.flows.affine import Affine as Affine
+ from bayinx.machinery.variational.flows.planar import Planar as Planar
+ from bayinx.machinery.variational.flows.radial import Radial as Radial
bayinx/machinery/variational/flows/affine.py ADDED
@@ -0,0 +1,68 @@
+ from functools import partial
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Float, Scalar
+
+ from bayinx.core import Flow
+
+
+ class Affine(Flow):
+     """
+     An affine flow.
+
+     # Attributes
+     - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
+     - `constraints`: A dictionary of constraining transformations.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+         eqx.field(static=True)
+     )
+
+     def __init__(self, dim: int):
+         """
+         Initializes an affine flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         """
+         self.params = {
+             "shift": jnp.zeros(dim),
+             "scale": jnp.zeros((dim, dim)),
+         }
+
+         self.constraints = {"scale": lambda m: jnp.tril(jnp.exp(m))}
+
+     @eqx.filter_jit
+     def forward(self, draws: Array) -> Array:
+         params = self.constrain_pars()
+
+         # Extract parameters
+         shift: Array = params["shift"]
+         scale: Array = params["scale"]
+
+         # Compute forward transformation
+         draws = draws @ scale + shift
+
+         return draws
+
+     @eqx.filter_jit
+     @partial(jax.vmap, in_axes=(None, 0))
+     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+         params = self.constrain_pars()
+
+         # Extract parameters
+         shift: Array = params["shift"]
+         scale: Array = params["scale"]
+
+         # Compute forward transformation
+         draws = draws @ scale + shift
+
+         # Compute ladj
+         ladj: Scalar = jnp.log(jnp.diag(scale)).sum()
+
+         return ladj, draws
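Since the constrained `scale` is lower triangular with a positive diagonal, the `ladj` above equals the log-determinant of `scale`; a small sketch (assuming the import paths in this diff) cross-checking that:

```python
import jax.numpy as jnp

from bayinx.machinery.variational.flows import Affine

flow = Affine(dim=3)
scale = flow.constrain_pars()["scale"]

# adjust_density is vmapped over draws, so pass a batch of one
ladj, _ = flow.adjust_density(jnp.zeros((1, 3)))
print(ladj[0], jnp.linalg.slogdet(scale)[1])  # these should agree
```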
bayinx/machinery/variational/flows/planar.py ADDED
@@ -0,0 +1,76 @@
+ from functools import partial
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jaxtyping import Array, Float, Scalar
+
+ from bayinx.core import Flow
+
+
+ class Planar(Flow):
+     """
+     A planar flow.
+
+     # Attributes
+     - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+     - `constraints`: A dictionary of constraining transformations.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+         eqx.field(static=True)
+     )
+
+     def __init__(self, dim: int, key=jr.PRNGKey(0)):
+         """
+         Initializes a planar flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         """
+         self.params = {
+             "u": jr.normal(key, (dim,)),
+             "w": jr.normal(key, (dim,)),
+             "b": jr.normal(key, (1,)),
+         }
+         self.constraints = {}
+
+     @eqx.filter_jit
+     @partial(jax.vmap, in_axes=(None, 0))
+     def forward(self, draws: Array) -> Array:
+         params = self.constrain_pars()
+
+         # Extract parameters
+         w: Array = params["w"]
+         u: Array = params["u"]
+         b: Array = params["b"]
+
+         # Compute forward transformation
+         draws = draws + u * jnp.tanh(draws.dot(w) + b)
+
+         return draws
+
+     @eqx.filter_jit
+     @partial(jax.vmap, in_axes=(None, 0))
+     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+         params = self.constrain_pars()
+
+         # Extract parameters
+         w: Array = params["w"]
+         u: Array = params["u"]
+         b: Array = params["b"]
+
+         # Compute shared intermediates
+         x: Array = draws.dot(w) + b
+
+         # Compute forward transformation
+         draws = draws + u * jnp.tanh(x)
+
+         # Compute ladj
+         h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
+         ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
+
+         return ladj, draws
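The planar Jacobian is $I + u\,h'(w^\top z + b)\,w^\top$, whose determinant is $1 + h'(w^\top z + b)\,u^\top w$ by the matrix determinant lemma, matching the `ladj` above. A sketch (not part of the package) cross-checking it against an autodiff Jacobian:

```python
import jax
import jax.numpy as jnp

from bayinx.machinery.variational.flows import Planar

flow = Planar(dim=2)
draw = jnp.array([0.3, -1.2])

# forward/adjust_density are vmapped, so wrap the single draw in a batch of one
ladj, _ = flow.adjust_density(draw[None, :])
jac = jax.jacfwd(lambda z: flow.forward(z[None, :])[0])(draw)
print(ladj[0], jnp.log(jnp.abs(jnp.linalg.det(jac))))  # these should agree
```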
bayinx/machinery/variational/flows/radial.py ADDED
@@ -0,0 +1,95 @@
+ from functools import partial
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jax.numpy.linalg import norm
+ from jaxtyping import Array, Float, Scalar
+
+ from bayinx.core import Flow
+
+
+ class Radial(Flow):
+     """
+     A radial flow.
+
+     # Attributes
+     - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+     - `constraints`: A dictionary of constraining transformations.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+         eqx.field(static=True)
+     )
+
+     def __init__(self, dim: int, key=jr.PRNGKey(0)):
+         """
+         Initializes a radial flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         """
+         self.params = {
+             "alpha": jnp.array(1.0),
+             "beta": jnp.array(1.0),
+             "center": jnp.ones(dim),
+         }
+         self.constraints = {"beta": jnp.exp}
+
+     @partial(jax.vmap, in_axes=(None, 0))
+     @eqx.filter_jit
+     def forward(self, draws: Array) -> Array:
+         """
+         Applies the forward radial transformation for each draw.
+
+         # Parameters
+         - `draws`: Draws from some layer of a normalizing flow.
+
+         # Returns
+         The transformed samples.
+         """
+         params = self.transform_pars()
+
+         # Extract parameters
+         alpha = params["alpha"]
+         beta = params["beta"]
+         center = params["center"]
+
+         # Compute distance to center per-draw
+         r: Array = norm(draws - center)
+
+         # Apply forward transformation
+         draws = draws + (beta / (alpha + r)) * (draws - center)
+
+         return draws
+
+     @partial(jax.vmap, in_axes=(None, 0))
+     @eqx.filter_jit
+     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+         params = self.transform_pars()
+
+         # Extract parameters
+         alpha = params["alpha"]
+         beta = params["beta"]
+         center = params["center"]
+
+         # Compute distance to center per-draw
+         r: Array = norm(draws - center)
+
+         # Compute shared intermediates
+         x: Array = beta / (alpha + r)
+
+         # Apply forward transformation
+         draws = draws + (x) * (draws - center)
+
+         # Compute density adjustment
+         ladj = jnp.log(
+             jnp.abs(
+                 (1.0 + alpha * beta / (alpha + r) ** 2.0) * (1.0 + x) ** (center.size - 1.0)
+             )
+         )
+
+         return ladj, draws
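The `ladj` here uses the closed form $\det J = (1 + \alpha\beta/(\alpha+r)^2)\,(1 + \beta/(\alpha+r))^{d-1}$, where the second factor carries multiplicity $d-1$ from the directions orthogonal to $z - c$. The same autodiff cross-check as for `Planar` applies (again a sketch, not part of the package):

```python
import jax
import jax.numpy as jnp

from bayinx.machinery.variational.flows import Radial

flow = Radial(dim=3)
draw = jnp.array([0.5, -0.7, 2.0])

ladj, _ = flow.adjust_density(draw[None, :])
jac = jax.jacfwd(lambda z: flow.forward(z[None, :])[0])(draw)
print(ladj[0], jnp.log(jnp.abs(jnp.linalg.det(jac))))  # these should agree
```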
bayinx/machinery/variational/flows/sylvester.py ADDED
@@ -0,0 +1,76 @@
+ from functools import partial
+ from typing import Callable, Dict, Tuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jr
+ from jaxtyping import Array, Float, Scalar
+
+ from bayinx.core import Flow
+
+
+ class Sylvester(Flow):
+     """
+     A Sylvester flow.
+
+     # Attributes
+     - `params`: A dictionary containing the JAX Arrays representing the flow parameters.
+     - `constraints`: A dictionary of constraining transformations.
+     """
+
+     params: Dict[str, Float[Array, "..."]]
+     constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]] = (
+         eqx.field(static=True)
+     )
+
+     def __init__(self, dim: int, key=jr.PRNGKey(0)):
+         """
+         Initializes a Sylvester flow.
+
+         # Parameters
+         - `dim`: The dimension of the parameter space.
+         """
+         self.params = {
+             "u": jr.normal(key, (dim,)),
+             "w": jr.normal(key, (dim,)),
+             "b": jr.normal(key, (1,)),
+         }
+         self.constraints = {}
+
+     @eqx.filter_jit
+     @partial(jax.vmap, in_axes=(None, 0))
+     def forward(self, draws: Array) -> Array:
+         params = self.constrain_pars()
+
+         # Extract parameters
+         w: Array = params["w"]
+         u: Array = params["u"]
+         b: Array = params["b"]
+
+         # Compute forward transformation
+         draws = draws + u * jnp.tanh(draws.dot(w) + b)
+
+         return draws
+
+     @eqx.filter_jit
+     @partial(jax.vmap, in_axes=(None, 0))
+     def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
+         params = self.constrain_pars()
+
+         # Extract parameters
+         w: Array = params["w"]
+         u: Array = params["u"]
+         b: Array = params["b"]
+
+         # Compute shared intermediates
+         x: Array = draws.dot(w) + b
+
+         # Compute forward transformation
+         draws = draws + u * jnp.tanh(x)
+
+         # Compute ladj
+         h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
+         ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
+
+         return ladj, draws
bayinx/machinery/variational/meanfield.py ADDED
@@ -0,0 +1,124 @@
+ from typing import Any, Callable, Dict, Self
+
+ import equinox as eqx
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree_util as jtu
+ from jax.flatten_util import ravel_pytree
+ from jaxtyping import Array, Float, Key, Scalar
+
+ from bayinx.core import Model, Variational
+ from bayinx.dists import normal
+
+
+ class MeanField(Variational):
+     """
+     A fully factorized Gaussian approximation to a posterior distribution.
+
+     # Attributes
+     - `var_params`: The variational parameters for the approximation.
+     """
+
+     var_params: Dict[str, Float[Array, "..."]]
+     _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
+     _constraints: Model = eqx.field(static=True)
+
+     def __init__(self, model: Model):
+         """
+         Constructs an unoptimized meanfield posterior approximation.
+
+         # Parameters
+         - `model`: A probabilistic `Model` object.
+         """
+         # Partition model
+         params, self._constraints = eqx.partition(model, eqx.is_array)
+
+         # Flatten params component
+         flat_params, self._unflatten = ravel_pytree(params)
+
+         # Initialize variational parameters
+         self.var_params = {
+             "mean": flat_params,
+             "log_std": jnp.zeros(flat_params.size, dtype=flat_params.dtype),
+         }
+
+     @eqx.filter_jit
+     def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
+         # Sample variational draws
+         draws: Array = (
+             jr.normal(key=key, shape=(n, self.var_params["mean"].size))
+             * jnp.exp(self.var_params["log_std"])
+             + self.var_params["mean"]
+         )
+
+         return draws
+
+     @eqx.filter_jit
+     def eval(self, draws: Array) -> Array:
+         return normal.logprob(
+             x=draws,
+             mu=self.var_params["mean"],
+             sigma=jnp.exp(self.var_params["log_std"]),
+         ).sum(axis=1)
+
+     @eqx.filter_jit
+     def filter_spec(self):
+         filter_spec = jtu.tree_map(lambda _: False, self)
+         filter_spec = eqx.tree_at(
+             lambda mf: mf.var_params,
+             filter_spec,
+             replace=True,
+         )
+         return filter_spec
+
+     @eqx.filter_jit
+     def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
+         """
+         Estimate the ELBO.
+         """
+         # Partition variational
+         dyn, static = eqx.partition(self, self.filter_spec())
+
+         @eqx.filter_jit
+         def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
+             # Combine
+             vari = eqx.combine(dyn, static)
+
+             # Sample draws from variational distribution
+             draws: Array = vari.sample(n, key)
+
+             # Evaluate posterior density for each draw
+             posterior_evals: Array = vari.eval_model(draws, data)
+
+             # Evaluate variational density for each draw
+             variational_evals: Array = vari.eval(draws)
+
+             # Evaluate ELBO
+             return jnp.mean(posterior_evals - variational_evals)
+
+         return elbo(dyn, n, key, data)
+
+     @eqx.filter_jit
+     def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
+         # Partition
+         dyn, static = eqx.partition(self, self.filter_spec())
+
+         @eqx.filter_grad
+         @eqx.filter_jit
+         def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
+             # Combine
+             vari = eqx.combine(dyn, static)
+
+             # Sample draws from variational distribution
+             draws: Array = vari.sample(n, key)
+
+             # Evaluate posterior density for each draw
+             posterior_evals: Array = vari.eval_model(draws, data)
+
+             # Evaluate variational density for each draw
+             variational_evals: Array = vari.eval(draws)
+
+             # Evaluate ELBO
+             return jnp.mean(posterior_evals - variational_evals)
+
+         return elbo_grad(dyn, n, key, data)
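A short sketch of interacting with the approximation directly, again assuming the hypothetical `NormalMean` model and the `data` array from the earlier sketches:

```python
import jax.random as jr

from bayinx.machinery.variational import MeanField

vari = MeanField(NormalMean())

# Monte Carlo ELBO estimate from 100 variational draws
print(vari.elbo(100, jr.PRNGKey(0), data))

# Draws from the (flattened) variational distribution, shape (500, dim)
draws = vari.sample(500, jr.PRNGKey(1))
```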
bayinx/machinery/variational/normalizing_flow.py ADDED
@@ -0,0 +1,152 @@
+ from typing import Any, Callable, Self, Tuple
+
+ import equinox as eqx
+ import jax.flatten_util as jfu
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree_util as jtu
+ from jaxtyping import Array, Float, Key, Scalar
+
+ from bayinx.core import Flow, Model, Variational
+
+
+ class NormalizingFlow(Variational):
+     """
+     An ordered collection of diffeomorphisms that map a base distribution to a
+     normalized approximation of a posterior distribution.
+
+     # Attributes
+     - `base`: A base variational distribution.
+     - `flows`: An ordered collection of continuously parameterized
+       diffeomorphisms.
+     """
+
+     flows: list[Flow]
+     base: Variational
+     _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
+     _constraints: Model = eqx.field(static=True)
+
+     def __init__(self, base: Variational, flows: list[Flow], model: Model):
+         """
+         Constructs an unoptimized normalizing flow posterior approximation.
+
+         # Parameters
+         - `base`: The base variational distribution.
+         - `flows`: A list of diffeomorphisms.
+         - `model`: A probabilistic `Model` object.
+         """
+         # Partition model
+         params, self._constraints = eqx.partition(model, eqx.is_array)
+
+         # Flatten params component
+         flat_params, self._unflatten = jfu.ravel_pytree(params)
+
+         self.base = base
+         self.flows = flows
+
+     @eqx.filter_jit
+     def sample(self, n: int, key: Key = jr.PRNGKey(0)):
+         """
+         Sample from the variational distribution `n` times.
+         """
+         # Sample from the base distribution
+         draws: Array = self.base.sample(n, key)
+
+         # Apply forward transformations
+         for map in self.flows:
+             draws = map.forward(draws)
+
+         return draws
+
+     @eqx.filter_jit
+     def eval(self, draws: Array) -> Array:
+         # Evaluate base density
+         variational_evals: Array = self.base.eval(draws)
+
+         for map in self.flows:
+             # Compute adjustment
+             ladj, draws = map.adjust_density(draws)
+
+             # Adjust variational density
+             variational_evals = variational_evals - ladj
+
+         return variational_evals
+
+     @eqx.filter_jit
+     def _eval(self, draws: Array, data=None) -> Tuple[Scalar, Array]:
+         """
+         Evaluate the posterior and variational densities at the transformed
+         `draws` to avoid extra compute when requiring variational draws for
+         the posterior evaluation.
+
+         # Parameters
+         - `draws`: Draws from the base variational distribution.
+         - `data`: Any data required to evaluate the posterior density.
+
+         # Returns
+         The posterior and variational densities.
+         """
+         # Evaluate base density
+         variational_evals: Array = self.base.eval(draws)
+
+         for map in self.flows:
+             # Compute adjustment
+             ladj, draws = map.adjust_density(draws)
+
+             # Adjust variational density
+             variational_evals = variational_evals - ladj
+
+         # Evaluate posterior at final variational draws
+         posterior_evals = self.eval_model(draws, data)
+
+         return posterior_evals, variational_evals
+
+     def filter_spec(self):
+         # Only optimize the parameters of the flows
+         filter_spec = jtu.tree_map(lambda _: False, self)
+         filter_spec = eqx.tree_at(
+             lambda nf: nf.flows,
+             filter_spec,
+             replace=True,
+         )
+
+         return filter_spec
+
+     @eqx.filter_jit
+     def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
+         # Partition
+         dyn, static = eqx.partition(self, self.filter_spec())
+
+         @eqx.filter_jit
+         def elbo(dyn: Self, n: int, key: Key, data: Any = None):
+             # Combine
+             self = eqx.combine(dyn, static)
+
+             # Sample draws from the base variational distribution
+             draws: Array = self.base.sample(n, key)
+
+             posterior_evals, variational_evals = self._eval(draws, data)
+             # Evaluate ELBO
+             return jnp.mean(posterior_evals - variational_evals)
+
+         return elbo(dyn, n, key, data)
+
+     @eqx.filter_jit
+     def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
+         # Partition
+         dyn, static = eqx.partition(self, self.filter_spec())
+
+         @eqx.filter_grad
+         @eqx.filter_jit
+         def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
+             # Combine
+             self = eqx.combine(dyn, static)
+
+             # Sample draws from the base variational distribution
+             draws: Array = self.base.sample(n, key)
+
+             posterior_evals, variational_evals = self._eval(draws, data)
+             # Evaluate ELBO
+             return jnp.mean(posterior_evals - variational_evals)
+
+         return elbo_grad(dyn, n, key, data)
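A sketch of composing the pieces in this diff: a `Standard` base (defined next), a couple of flow layers, and a model (the hypothetical `NormalMean` and `data` from the earlier sketches; `dim=1` because that model has a single scalar parameter):

```python
import jax.random as jr

from bayinx.machinery.variational import NormalizingFlow, Standard
from bayinx.machinery.variational.flows import Affine, Planar

model = NormalMean()

# Base N(0, I) distribution pushed through an affine layer, then a planar layer
vari = NormalizingFlow(
    base=Standard(model),
    flows=[Affine(dim=1), Planar(dim=1, key=jr.PRNGKey(2))],
    model=model,
)
vari = vari.fit(max_iters=1000, data=data, learning_rate=0.1)
```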
bayinx/machinery/variational/standard.py ADDED
@@ -0,0 +1,67 @@
+ from typing import Callable
+
+ import equinox as eqx
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree_util as jtu
+ from jax.flatten_util import ravel_pytree
+ from jaxtyping import Array, Float, Key
+
+ from bayinx.core import Model, Variational
+ from bayinx.dists import normal
+
+
+ class Standard(Variational):
+     """
+     A standard normal distribution approximation to a posterior distribution.
+
+     # Attributes
+     - `dim`: Dimension of the parameter space.
+     """
+
+     dim: int = eqx.field(static=True)
+     _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
+     _constraints: Model = eqx.field(static=True)
+
+     def __init__(self, model: Model):
+         """
+         Constructs a standard normal approximation to a posterior distribution.
+
+         # Parameters
+         - `model`: A probabilistic `Model` object.
+         """
+         # Partition model
+         params, self._constraints = eqx.partition(model, eqx.is_array)
+
+         # Flatten params component
+         flat_params, self._unflatten = ravel_pytree(params)
+
+         # Store dimension of parameter space
+         self.dim = jnp.size(flat_params)
+
+     @eqx.filter_jit
+     def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
+         # Sample variational draws
+         draws: Array = jr.normal(key=key, shape=(n, self.dim))
+
+         return draws
+
+     @eqx.filter_jit
+     def eval(self, draws: Array) -> Array:
+         return normal.logprob(
+             x=draws,
+             mu=jnp.array(0.0),
+             sigma=jnp.array(1.0),
+         ).sum(axis=1)
+
+     @eqx.filter_jit
+     def filter_spec(self):
+         filter_spec = jtu.tree_map(lambda _: False, self)
+
+         return filter_spec
+
+     def elbo(self):
+         return None
+
+     def elbo_grad(self):
+         return None
1
+ Metadata-Version: 2.4
2
+ Name: bayinx
3
+ Version: 0.2.2
4
+ Summary: Bayesian Inference with JAX
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: equinox>=0.11.12
7
+ Requires-Dist: jax>=0.4.38
8
+ Requires-Dist: jaxtyping>=0.2.36
9
+ Requires-Dist: optax>=0.2.4
10
+ Requires-Dist: pytest-benchmark>=5.1.0
11
+ Requires-Dist: pytest>=8.3.5
12
+ Description-Content-Type: text/markdown
13
+
14
+ # <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
15
+
16
+ The endgoal of this project is to build a Bayesian inference library that is similar in feel to `Stan`(where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models(e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of methods for estimation(point optimization, variational methods, MCMC) while keeping everything performant(hence using `JAX`).
17
+
18
+ In the short-term, I'm going to focus on:
19
+ 1) Implementing as much machinery as I feel is enough.
20
+ 2) Figuring out how to design the `Model` superclass to have something like the `transformed pars {}` block but unifies transformations and constraints.
21
+ 3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
22
+
23
+ In the long-term, I'm going to focus on:
24
+ 1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead(to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed(same as `Stan` does).
25
+ 2) How to make working with the posterior as easy as possible.
26
+ - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
27
+
28
+ Although this is somewhat separate from the goals of the project, if this does pan out how I'm invisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar(think `brms`).
29
+
30
+ Additionally, when I get around to it I'd like the package documentation to also include theoretical and implementation details for all machinery implemented(with overthinking boxes because I do like that design from McElreath's book).
31
+
32
+
33
+ # TODO
34
+ - Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
35
+ - Look into adaptively tuning ADAM hyperparameters.
36
+ - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
37
+ - Low-rank affine flow?
38
+ - https://arxiv.org/pdf/1803.05649 implement sylvester flows.
39
+ - Learn how to generate documentation lol.
40
+ - Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
bayinx-0.2.2.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+ bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
+ bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
+ bayinx/core/flow.py,sha256=4vj1t2xNPGp1VPE4xUshY-rHAw__KvSwjGDtKkW2taE,2252
+ bayinx/core/model.py,sha256=AI4eHrXAds3K7eWgZ9g5E6Kh76HP6WTn6s6q_0tnhck,1719
+ bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
+ bayinx/core/variational.py,sha256=T42uUNkF2tP1HJPyeIv7ISdika_G28wR_OOFXzx_hgo,4978
+ bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
+ bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
+ bayinx/machinery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bayinx/machinery/variational/__init__.py,sha256=5GdhqBOHKsXg2tZGAMNlxyrLPD0-s64wAEy8998cHZ4,247
+ bayinx/machinery/variational/meanfield.py,sha256=stI96DNhHNIROnr8rLNEN9SN_0lXqkwit0KNto34Q6A,3889
+ bayinx/machinery/variational/normalizing_flow.py,sha256=qypgPq9vIqSIJNOHCDaN-hvwFfttNQ5_yXqvmi5hslI,4796
+ bayinx/machinery/variational/standard.py,sha256=IQdNd5QIE8u3zcOw7K4EW69lIQ0ZUGGDvwZVyvrYHxA,1739
+ bayinx/machinery/variational/flows/__init__.py,sha256=VGh-ffuUfMso_0JxwGCJQ2yVnFJdOrkFsSnorojQldY,213
+ bayinx/machinery/variational/flows/affine.py,sha256=TPyUUPRoSkyDMwGO5wtq-Ei8DAUvlb_N6JCk7uPlbJQ,1748
+ bayinx/machinery/variational/flows/planar.py,sha256=rJ1XpqoWzig_5Udq6oCh5JV4ptlTRLRS7tb9DCX22lE,2013
+ bayinx/machinery/variational/flows/radial.py,sha256=NF1tCd_PH6m8eqjJkom2c30sRUQ04Vf8zeRx_RQCDcg,2526
+ bayinx/machinery/variational/flows/sylvester.py,sha256=DeZl4Fkz9XpCGsfDcjS0eWlrMR0xMO0MPfJvoDhixSA,2019
+ bayinx-0.2.2.dist-info/METADATA,sha256=gNOkUv-EdtqnLp154L29ycWd4t2c2wsRmyqhNXYc1C8,3099
+ bayinx-0.2.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ bayinx-0.2.2.dist-info/RECORD,,
bayinx-0.1.0.dist-info/METADATA DELETED
@@ -1,8 +0,0 @@
- Metadata-Version: 2.4
- Name: bayinx
- Version: 0.1.0
- Summary: A personal library for Bayesian inference
- Requires-Python: >=3.12
- Requires-Dist: baycompx
- Requires-Dist: jax>=0.4.38
- Requires-Dist: jaxtyping>=0.2.36
bayinx-0.1.0.dist-info/RECORD DELETED
@@ -1,5 +0,0 @@
- bayinx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bayinx-0.1.0.dist-info/METADATA,sha256=vo6K-1pFW6eW-SeCGyfJSxjkJ92h21qxUgLeVCL8C3Q,209
- bayinx-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- bayinx-0.1.0.dist-info/RECORD,,
File without changes