bayinx 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,16 +9,12 @@ from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
9
9
  class Constraint(eqx.Module):
10
10
  """
11
11
  Abstract base class for defining parameter constraints.
12
-
13
- Subclasses should implement the `constrain` method to apply the
14
- transformation and compute the ladj adjustment.
15
12
  """
13
+
16
14
  @abstractmethod
17
15
  def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
18
16
  """
19
- Applies the constraining transformation to an unconstrained input
20
- and computes the log absolute determinant of the Jacobian (ladj)
21
- of this transformation.
17
+ Applies the constraining transformation to an unconstrained input and computes the log-absolute-jacobian of the transformation.
22
18
 
23
19
  # Parameters
24
20
  - `x`: The unconstrained JAX Array-like input.
@@ -26,7 +22,7 @@ class Constraint(eqx.Module):
26
22
  # Returns
27
23
  A tuple containing:
28
24
  - The constrained JAX Array.
29
- - A scalar JAX Array representing the ladj of the transformation.
25
+ - A scalar JAX Array representing the laj of the transformation.
30
26
  """
31
27
  pass
32
28
 
@@ -35,6 +31,7 @@ class LowerBound(Constraint):
35
31
  """
36
32
  Enforces a lower bound on the parameter.
37
33
  """
34
+
38
35
  lb: ScalarLike
39
36
 
40
37
  def __init__(self, lb: ScalarLike):
@@ -42,7 +39,7 @@ class LowerBound(Constraint):
42
39
 
43
40
  def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
44
41
  """
45
- Applies the lower bound constraint and computes the ladj.
42
+ Applies the lower bound constraint and computes the laj.
46
43
 
47
44
  # Parameters
48
45
  - `x`: The unconstrained JAX Array-like input.
@@ -50,10 +47,10 @@ class LowerBound(Constraint):
50
47
  # Parameters
51
48
  A tuple containing:
52
49
  - The constrained JAX Array (x > self.lb).
53
- - A scalar JAX Array representing the ladj of the transformation.
50
+ - A scalar JAX Array representing the laj of the transformation.
54
51
  """
55
52
  # Compute transformation adjustment
56
- ladj = jnp.sum(x)
53
+ ladj: Scalar = jnp.sum(x)
57
54
 
58
55
  # Compute transformation
59
56
  x = jnp.exp(x) + self.lb
bayinx/core/flow.py CHANGED
@@ -8,11 +8,11 @@ from jaxtyping import Array, Float
8
8
 
9
9
  class Flow(eqx.Module):
10
10
  """
11
- A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
11
+ An abstract base class for a flow(of a normalizing flow).
12
12
 
13
13
  # Attributes
14
14
  - `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
15
- - `constraints`: A dictionary of functions that constrain their corresponding parameter.
15
+ - `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
16
16
  """
17
17
 
18
18
  params: Dict[str, Float[Array, "..."]]
@@ -28,10 +28,10 @@ class Flow(eqx.Module):
28
28
  @abstractmethod
29
29
  def adjust_density(self, draws: Array) -> Tuple[Array, Array]:
30
30
  """
31
- Computes the log-absolute-determinant of the Jacobian at `draws` and applies the forward transformation.
31
+ Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.
32
32
 
33
33
  # Returns
34
- A tuple of JAX Arrays containing the log-absolute-determinant of the Jacobians and transformed draws.
34
+ A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
35
35
  """
36
36
  pass
37
37
 
bayinx/core/model.py CHANGED
@@ -11,11 +11,11 @@ from bayinx.core.constraints import Constraint
11
11
 
12
12
  class Model(eqx.Module):
13
13
  """
14
- A superclass used to define probabilistic models.
14
+ An abstract base class used to define probabilistic models.
15
15
 
16
16
  # Attributes
17
17
  - `params`: A dictionary of JAX Arrays representing parameters of the model.
18
- - `constraints`: A dictionary of functions that constrain their corresponding parameter.
18
+ - `constraints`: A dictionary of constraints.
19
19
  """
20
20
 
21
21
  params: Dict[str, Array]
@@ -63,7 +63,7 @@ class Model(eqx.Module):
63
63
 
64
64
  return t_params, target
65
65
 
66
-
66
+ # Add default transform method
67
67
  def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
68
68
  """
69
69
  Apply a custom transformation to `params` if needed.
@@ -8,7 +8,7 @@ import jax.lax as lax
8
8
  import jax.numpy as jnp
9
9
  import jax.random as jr
10
10
  import optax as opx
11
- from jaxtyping import Array, Float, Key, PyTree, Scalar
11
+ from jaxtyping import Array, Key, PyTree, Scalar
12
12
  from optax import GradientTransformation, OptState, Schedule
13
13
 
14
14
  from bayinx.core import Model
@@ -16,16 +16,23 @@ from bayinx.core import Model
16
16
 
17
17
  class Variational(eqx.Module):
18
18
  """
19
- A superclass used to define variational methods.
19
+ An abstract base class used to define variational methods.
20
20
 
21
21
  # Attributes
22
22
  - `_unflatten`: A static function to transform draws from the variational distribution back to a `Model`.
23
23
  - `_constraints`: A static partitioned `Model` with the constraints of the `Model` used to initialize the `Variational` object.
24
24
  """
25
25
 
26
- _unflatten: Callable[[Float[Array, "..."]], Model]
26
+ _unflatten: Callable[[Array], Model]
27
27
  _constraints: Model
28
28
 
29
+ @abstractmethod
30
+ def filter_spec(self):
31
+ """
32
+ Filter specification for dynamic and static components of the `Variational`.
33
+ """
34
+ pass
35
+
29
36
  @abstractmethod
30
37
  def sample(self, n: int, key: Key) -> Array:
31
38
  """
@@ -54,13 +61,6 @@ class Variational(eqx.Module):
54
61
  """
55
62
  pass
56
63
 
57
- @abstractmethod
58
- def filter_spec(self):
59
- """
60
- Filter specification for dynamic and static components of the `Variational`.
61
- """
62
- pass
63
-
64
64
  @eqx.filter_jit
65
65
  @partial(jax.vmap, in_axes=(None, 0, None))
66
66
  def eval_model(self, draws: Array, data: Any = None) -> Array:
@@ -135,7 +135,7 @@ class Variational(eqx.Module):
135
135
  # Update PRNG key
136
136
  key, _ = jr.split(key)
137
137
 
138
- # Combine variational
138
+ # Reconstruct variational
139
139
  vari = eqx.combine(dyn, static)
140
140
 
141
141
  # Compute gradient of the ELBO
bayinx/dists/normal.py CHANGED
@@ -19,7 +19,7 @@ def prob(
19
19
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
20
20
  """
21
21
 
22
- return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / ( # pyright: ignore
22
+ return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / ( # pyright: ignore
23
23
  sigma * _lax.sqrt(2.0 * __PI)
24
24
  )
25
25
 
@@ -39,7 +39,9 @@ def logprob(
39
39
  The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
40
40
  """
41
41
 
42
- return -_lax.log(sigma * _lax.sqrt(2.0 * __PI)) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
42
+ return -_lax.log(sigma * _lax.sqrt(2.0 * __PI)) - 0.5 * _lax.square(
43
+ (x - mu) / sigma # pyright: ignore
44
+ )
43
45
 
44
46
 
45
47
  def uprob(
@@ -57,7 +59,7 @@ def uprob(
57
59
  The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
58
60
  """
59
61
 
60
- return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma # pyright: ignore
62
+ return _lax.exp(-0.5 * _lax.square((x - mu) / sigma)) / sigma # pyright: ignore
61
63
 
62
64
 
63
65
  def ulogprob(
@@ -75,4 +77,4 @@ def ulogprob(
75
77
  The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
76
78
  """
77
79
 
78
- return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
80
+ return -_lax.log(sigma) - 0.5 * _lax.square((x - mu) / sigma) # pyright: ignore
bayinx/dists/uniform.py CHANGED
@@ -18,7 +18,7 @@ def prob(
18
18
  The PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
19
19
  """
20
20
 
21
- return 1.0 / (ub - lb) # pyright: ignore
21
+ return 1.0 / (ub - lb) # pyright: ignore
22
22
 
23
23
 
24
24
  def logprob(
@@ -36,7 +36,7 @@ def logprob(
36
36
  The log of the PDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
37
37
  """
38
38
 
39
- return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
39
+ return _lax.log(1.0) - _lax.log(ub - lb) # pyright: ignore
40
40
 
41
41
 
42
42
  def uprob(
@@ -54,7 +54,7 @@ def uprob(
54
54
  The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
55
55
  """
56
56
 
57
- return jnp.ones(jnp.broadcast_arrays(x,lb,ub))
57
+ return jnp.ones(jnp.broadcast_arrays(x, lb, ub))
58
58
 
59
59
 
60
60
  def ulogprob(
@@ -72,4 +72,4 @@ def ulogprob(
72
72
  The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `lb`, and `ub`.
73
73
  """
74
74
 
75
- return jnp.zeros(jnp.broadcast_arrays(x,lb,ub))
75
+ return jnp.zeros(jnp.broadcast_arrays(x, lb, ub))
@@ -1,29 +1,26 @@
1
1
  from functools import partial
2
- from typing import Callable, Dict, Tuple
2
+ from typing import Tuple
3
3
 
4
4
  import equinox as eqx
5
5
  import jax
6
6
  import jax.numpy as jnp
7
- from jaxtyping import Array, Float, Scalar
7
+ from jaxtyping import Array, Scalar
8
8
 
9
9
  from bayinx.core import Flow
10
10
 
11
11
 
12
12
  class FullAffine(Flow):
13
13
  """
14
- An affine flow.
14
+ A full affine flow.
15
15
 
16
16
  # Attributes
17
17
  - `params`: A dictionary containing the JAX Arrays representing the scale and shift parameters.
18
18
  - `constraints`: A dictionary of constraining transformations.
19
19
  """
20
20
 
21
- params: Dict[str, Float[Array, "..."]]
22
- constraints: Dict[str, Callable[[Float[Array, "..."]], Float[Array, "..."]]]
23
-
24
21
  def __init__(self, dim: int):
25
22
  """
26
- Initializes an affine flow.
23
+ Initializes a full affine flow.
27
24
 
28
25
  # Parameters
29
26
  - `dim`: The dimension of the parameter space.
@@ -35,21 +32,18 @@ class FullAffine(Flow):
35
32
 
36
33
  self.constraints = {"scale": lambda m: jnp.tril(m)}
37
34
 
35
+ @eqx.filter_jit
38
36
  def transform_pars(self):
39
- # Get constrained parameters
40
37
  params = self.constrain_pars()
41
38
 
42
39
  # Extract diagonal and apply exponential
43
- diag: Array = jnp.exp(jnp.diag(params['scale']))
40
+ diag: Array = jnp.exp(jnp.diag(params["scale"]))
44
41
 
45
42
  # Fill diagonal
46
- params['scale'] = jnp.fill_diagonal(params['scale'], diag, inplace=False)
47
-
43
+ params["scale"] = jnp.fill_diagonal(params["scale"], diag, inplace=False)
48
44
 
49
45
  return params
50
46
 
51
-
52
-
53
47
  @eqx.filter_jit
54
48
  def forward(self, draws: Array) -> Array:
55
49
  params = self.transform_pars()
@@ -65,7 +59,7 @@ class FullAffine(Flow):
65
59
 
66
60
  @eqx.filter_jit
67
61
  @partial(jax.vmap, in_axes=(None, 0))
68
- def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
62
+ def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
69
63
  params = self.transform_pars()
70
64
 
71
65
  # Extract parameters
@@ -75,7 +69,7 @@ class FullAffine(Flow):
75
69
  # Compute forward transformation
76
70
  draws = draws @ scale + shift
77
71
 
78
- # Compute ladj
79
- ladj: Scalar = jnp.log(jnp.diag(scale)).sum()
72
+ # Compute laj
73
+ laj: Scalar = jnp.log(jnp.diag(scale)).sum()
80
74
 
81
- return ladj, draws
75
+ return draws, laj
@@ -53,7 +53,7 @@ class Planar(Flow):
53
53
 
54
54
  @eqx.filter_jit
55
55
  @partial(jax.vmap, in_axes=(None, 0))
56
- def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
56
+ def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
57
57
  params = self.transform_pars()
58
58
 
59
59
  # Extract parameters
@@ -67,8 +67,8 @@ class Planar(Flow):
67
67
  # Compute forward transformation
68
68
  draws = draws + u * jnp.tanh(x)
69
69
 
70
- # Compute ladj
70
+ # Compute laj
71
71
  h_prime: Scalar = 1.0 - jnp.square(jnp.tanh(x))
72
- ladj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
72
+ laj: Scalar = jnp.log(jnp.abs(1.0 + h_prime * u.dot(w)))
73
73
 
74
- return ladj, draws
74
+ return draws, laj
@@ -66,7 +66,7 @@ class Radial(Flow):
66
66
 
67
67
  @partial(jax.vmap, in_axes=(None, 0))
68
68
  @eqx.filter_jit
69
- def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
69
+ def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
70
70
  params = self.transform_pars()
71
71
 
72
72
  # Extract parameters
@@ -84,11 +84,11 @@ class Radial(Flow):
84
84
  draws = draws + (x) * (draws - center)
85
85
 
86
86
  # Compute density adjustment
87
- ladj = jnp.log(
87
+ laj = jnp.log(
88
88
  jnp.abs(
89
89
  (1.0 + alpha * beta / (alpha + r) ** 2.0)
90
90
  * (1.0 + x) ** (center.size - 1.0)
91
91
  )
92
92
  )
93
93
 
94
- return ladj, draws
94
+ return draws, laj
@@ -20,8 +20,8 @@ class MeanField(Variational):
20
20
  """
21
21
 
22
22
  var_params: Dict[str, Float[Array, "..."]]
23
- _unflatten: Callable[[Float[Array, "..."]], Model] = eqx.field(static=True)
24
- _constraints: Model = eqx.field(static=True)
23
+ _unflatten: Callable[[Float[Array, "..."]], Model]
24
+ _constraints: Model
25
25
 
26
26
  def __init__(self, model: Model):
27
27
  """
@@ -63,25 +63,24 @@ class MeanField(Variational):
63
63
 
64
64
  @eqx.filter_jit
65
65
  def filter_spec(self):
66
+ # Generate empty specification
66
67
  filter_spec = jtu.tree_map(lambda _: False, self)
68
+
69
+ # Specify variational parameters
67
70
  filter_spec = eqx.tree_at(
68
71
  lambda mf: mf.var_params,
69
72
  filter_spec,
70
73
  replace=True,
71
74
  )
75
+
72
76
  return filter_spec
73
77
 
74
78
  @eqx.filter_jit
75
79
  def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
76
- """
77
- Estimate the ELBO and its gradient(w.r.t the variational parameters).
78
- """
79
- # Partition variational
80
80
  dyn, static = eqx.partition(self, self.filter_spec())
81
81
 
82
82
  @eqx.filter_jit
83
83
  def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
84
- # Combine
85
84
  vari = eqx.combine(dyn, static)
86
85
 
87
86
  # Sample draws from variational distribution
@@ -100,7 +99,6 @@ class MeanField(Variational):
100
99
 
101
100
  @eqx.filter_jit
102
101
  def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
103
- # Partition
104
102
  dyn, static = eqx.partition(self, self.filter_spec())
105
103
 
106
104
  @eqx.filter_grad
@@ -1,11 +1,11 @@
1
- from typing import Any, Callable, Self, Tuple
1
+ from typing import Any, Self, Tuple
2
2
 
3
3
  import equinox as eqx
4
4
  import jax.flatten_util as jfu
5
5
  import jax.numpy as jnp
6
6
  import jax.random as jr
7
7
  import jax.tree_util as jtu
8
- from jaxtyping import Array, Float, Key, Scalar
8
+ from jaxtyping import Array, Key, Scalar
9
9
 
10
10
  from bayinx.core import Flow, Model, Variational
11
11
 
@@ -17,14 +17,11 @@ class NormalizingFlow(Variational):
17
17
 
18
18
  # Attributes
19
19
  - `base`: A base variational distribution.
20
- - `flows`: An ordered collection of continuously parameterized
21
- diffeomorphisms.
20
+ - `flows`: An ordered collection of continuously parameterized diffeomorphisms.
22
21
  """
23
22
 
24
23
  flows: list[Flow]
25
24
  base: Variational
26
- _unflatten: Callable[[Float[Array, "..."]], Model]
27
- _constraints: Model
28
25
 
29
26
  def __init__(self, base: Variational, flows: list[Flow], model: Model):
30
27
  """
@@ -44,6 +41,19 @@ class NormalizingFlow(Variational):
44
41
  self.base = base
45
42
  self.flows = flows
46
43
 
44
+ def filter_spec(self):
45
+ # Generate empty specification
46
+ filter_spec = jtu.tree_map(lambda _: False, self)
47
+
48
+ # Specify variational parameters based on each flow's filter spec.
49
+ filter_spec = eqx.tree_at(
50
+ lambda vari: vari.flows,
51
+ filter_spec,
52
+ replace=[flow.filter_spec() for flow in self.flows],
53
+ )
54
+
55
+ return filter_spec
56
+
47
57
  @eqx.filter_jit
48
58
  def sample(self, n: int, key: Key = jr.PRNGKey(0)):
49
59
  """
@@ -65,19 +75,18 @@ class NormalizingFlow(Variational):
65
75
 
66
76
  for map in self.flows:
67
77
  # Compute adjustment
68
- ladj, draws = map.adjust_density(draws)
78
+ laj, draws = map.adjust_density(draws)
69
79
 
70
80
  # Adjust variational density
71
- variational_evals = variational_evals - ladj
81
+ variational_evals = variational_evals - laj
72
82
 
73
83
  return variational_evals
74
84
 
75
85
  @eqx.filter_jit
76
86
  def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
77
87
  """
78
- Evaluate the posterior and variational densities at the transformed
79
- `draws` to avoid extra compute when requiring variational draws for
80
- the posterior evaluation.
88
+ Evaluate the posterior and variational densities together at the
89
+ transformed `draws` to avoid extra compute.
81
90
 
82
91
  # Parameters
83
92
  - `draws`: Draws from the base variational distribution.
@@ -91,29 +100,16 @@ class NormalizingFlow(Variational):
91
100
 
92
101
  for map in self.flows:
93
102
  # Compute adjustment
94
- ladj, draws = map.adjust_density(draws)
103
+ draws, laj = map.adjust_density(draws)
95
104
 
96
105
  # Adjust variational density
97
- variational_evals = variational_evals - ladj
106
+ variational_evals = variational_evals - laj
98
107
 
99
108
  # Evaluate posterior at final variational draws
100
109
  posterior_evals = self.eval_model(draws, data)
101
110
 
102
111
  return posterior_evals, variational_evals
103
112
 
104
- def filter_spec(self):
105
- # Generate empty specification
106
- filter_spec = jtu.tree_map(lambda _: False, self)
107
-
108
- # Specify variational parameters based on each flow's filter spec.
109
- filter_spec = eqx.tree_at(
110
- lambda vari: vari.flows,
111
- filter_spec,
112
- replace=[flow.filter_spec() for flow in self.flows],
113
- )
114
-
115
- return filter_spec
116
-
117
113
  @eqx.filter_jit
118
114
  def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
119
115
  dyn, static = eqx.partition(self, self.filter_spec())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.27
3
+ Version: 0.2.28
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -0,0 +1,27 @@
1
+ bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
+ bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
+ bayinx/core/constraints.py,sha256=lbVs2-xjGRue17YRPGHz3s_mJ0ZiunpYowbD0QvcD-I,1525
5
+ bayinx/core/flow.py,sha256=A5Vw5t76LPasnMgghjw6ulBkIm5L2jBprusVt-tuwko,2296
6
+ bayinx/core/model.py,sha256=Z_HaFr0_-keMjG5tg3xxP3hGML7aDFIcCI8Y5dGrtM4,2145
7
+ bayinx/core/variational.py,sha256=W0747jfVJFAtMZqL3mpbtl2wfnARHln-dVBag4xZ23Y,4813
8
+ bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
10
+ bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ bayinx/dists/normal.py,sha256=3CXSgHWnuglmP8cKVUh2Yt4Rb9_LR_mwPRXDm_LuSRo,2679
14
+ bayinx/dists/uniform.py,sha256=mogFe8VuDelM9KXE6RxGek0-tuZYFrwmo_oMOPHXleA,2359
15
+ bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
16
+ bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
17
+ bayinx/mhx/vi/meanfield.py,sha256=8hM1KZ52TpRPLwiQcowsJLlQ-5nJzUEcKrtDiGrFoSs,3732
18
+ bayinx/mhx/vi/normalizing_flow.py,sha256=FvxDtqGRtaEeeF-bXCYnIEAvOOXVHKUK0oCTF9ma02Y,4622
19
+ bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
20
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
21
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=Kvaa8epqaqz9tdMCnf9T_-2P3Bh_TkhA6NrilKHY93A,1886
22
+ bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
23
+ bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
24
+ bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
25
+ bayinx-0.2.28.dist-info/METADATA,sha256=xe3Wlo3UlD3VuTc42ChwnPTL6lp3BZmxnuf0gnZxWv0,3058
26
+ bayinx-0.2.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
27
+ bayinx-0.2.28.dist-info/RECORD,,
bayinx/core/utils.py DELETED
@@ -1 +0,0 @@
1
-
@@ -1,28 +0,0 @@
1
- bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
- bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
5
- bayinx/core/flow.py,sha256=9swS5wh7AsIZWgG_IhQS-upcPlr7G-juaP_5rsbX6G0,2363
6
- bayinx/core/model.py,sha256=U1xBnAXnIvFJjWF-XIT8BYjP9PtoRZY_PwyhRwJf-HA,2144
7
- bayinx/core/utils.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
- bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
9
- bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
11
- bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- bayinx/dists/normal.py,sha256=rtSDi0NAObH1LGRWiPZk_6cbSVv2dOPHkgxtWn6gFgM,2662
15
- bayinx/dists/uniform.py,sha256=PSZIIc2QfNF5XYi-TLGltnr_vnAIA-MZsn1rKV8QXAo,2353
16
- bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
17
- bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
18
- bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
19
- bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
20
- bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
21
- bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
22
- bayinx/mhx/vi/flows/fullaffine.py,sha256=nXcPTZ_GSxIg7tmVxag694Fl1F95SKFSyDyt-9EDC9I,2049
23
- bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
24
- bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
25
- bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
26
- bayinx-0.2.27.dist-info/METADATA,sha256=5RPhGKmb6wWJquxrUlyt6QXWTSPEQycu5nFVZmQN9bU,3058
27
- bayinx-0.2.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
- bayinx-0.2.27.dist-info/RECORD,,